Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
a9601800
Commit
a9601800
authored
Dec 19, 2025
by
wenjh
Browse files
Fix build error
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
5cf21c3b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+6
-5
No files found.
transformer_engine/common/gemm/rocm_gemm.cu
View file @
a9601800
...
...
@@ -893,7 +893,7 @@ static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
}
static
void
DestroyHipBlasLtHandle
(
hipblasLtHandle_t
handle
)
{
if
(
handle
!=
nullptr
)
if
(
handle
!=
nullptr
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtDestroy
(
handle
));
}
}
...
...
@@ -1391,7 +1391,7 @@ struct HipBlasltUserArgsCache
{
HipBlasltUserArgsCache
()
{}
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
HipBlasltUserArgs
Buffer
&
operator
=
(
const
HipBlasltUserArgs
Buffer
&
)
=
delete
;
HipBlasltUserArgs
Cache
&
operator
=
(
const
HipBlasltUserArgs
Cache
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
{
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
...
...
@@ -1524,13 +1524,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
DType
input_type
=
inputB
[
0
]
->
data
.
dtype
;
DType
bias_type
=
bias
[
0
]
->
data
.
dtype
;
NVTE_CHECK
(
bias_type
==
DType
::
kFloat32
||
bias_type
==
DType
::
kFloat16
||
bias_type
==
DType
::
kBFloat16
);
for
(
int
i
=
0
;
i
<
m
.
size
();
++
i
)
{
void
*
input_ptr
=
inputB
[
i
]
->
data
.
dptr
;
void
*
bias_ptr
=
bias
[
i
]
->
data
.
dptr
;
batch_size
=
k
[
i
];
output_dim
=
n
[
i
];
int
batch_size
=
static_cast
<
int
>
(
k
[
i
]
)
;
int
output_dim
=
static_cast
<
int
>
(
n
[
i
]
)
;
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
input_
d
type
,
IType
,
input_type
,
IType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT
(
bias_type
,
OType
,
detail
::
bias_gradient_kernelLauncher
<
IType
,
OType
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment