Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
b3833972
Commit
b3833972
authored
Nov 12, 2025
by
wenjh
Browse files
Sync All on groupedgemm.
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
66bd0b32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
10 deletions
+10
-10
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+10
-10
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
b3833972
...
...
@@ -1284,11 +1284,11 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check that compute_stream_offset is valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
hipblaslt_ext
::
UserArguments
*
userArgs
=
get_hipblaslt_user_args
(
m
.
size
(),
true
);
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
get_hipblaslt_user_args
(
m
.
size
(),
false
);
//
hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
//
hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
//
hipblaslt_ext::UserArguments* userArgs;
//
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblaslt_ext
::
UserArguments
*
userArgs
;
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
hipblasLtHandle_t
handle
=
hipBlasLtHandleManager
::
Instance
().
GetHandle
();
...
...
@@ -1347,17 +1347,17 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Get the default values from the groupedgemm object
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
userArgs
);
// Copy them to device memory
//
hipblaslt_ext::UserArguments* d_userArgs;
//
NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
NVTE_CHECK_CUDA
(
hipMemcpy
(
d_userArgs
,
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
));
hipblaslt_ext
::
UserArguments
*
d_userArgs
;
NVTE_CHECK_CUDA
(
hipMallocAsync
(
&
d_userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
stream
));
NVTE_CHECK_CUDA
(
hipMemcpy
Async
(
d_userArgs
,
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
)
,
stream
);
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
d_userArgs
,
stream
));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
//
NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
//
NVTE_CHECK_CUDA(hipFree(userArgs));
NVTE_CHECK_CUDA
(
hipFreeAsync
(
d_userArgs
,
stream
));
NVTE_CHECK_CUDA
(
hipFree
(
userArgs
));
}
#endif //USE_HIPBLASLT
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment