Commit dcb27bde authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.5'

parents 0a0d309c 8665c111
...@@ -1603,6 +1603,8 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1603,6 +1603,8 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
return; return;
} }
// Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
// Get the default values from the grouepdgemm object // Get the default values from the grouepdgemm object
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs); groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory // Copy them to device memory
...@@ -1612,9 +1614,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1612,9 +1614,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
userArgs, userArgs,
m.size() * sizeof(hipblaslt_ext::UserArguments), m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice, stream)); hipMemcpyHostToDevice, stream));
// Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream)); NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream)); // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream)); // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment