Commit b3833972 authored by wenjh's avatar wenjh
Browse files

Sync All on groupedgemm.


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 66bd0b32
......@@ -1284,11 +1284,11 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
// hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
// hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblaslt_ext::UserArguments* userArgs;
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
......@@ -1347,17 +1347,17 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Get the default values from the grouepdgemm object
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice));
hipblaslt_ext::UserArguments* d_userArgs;
NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice), stream);
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
NVTE_CHECK_CUDA(hipFree(userArgs));
}
#endif //USE_HIPBLASLT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment