Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
e698a0a7
Commit
e698a0a7
authored
Dec 15, 2025
by
wenjh
Browse files
Fix blaslt group gemm crash
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
4086a4cc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
213 additions
and
59 deletions
+213
-59
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+2
-1
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+211
-58
No files found.
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
e698a0a7
...
@@ -1367,6 +1367,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
...
@@ -1367,6 +1367,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
bool
use_split_accumulator
,
int
math_sm_count
,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
if
(
num_gemms
==
0
)
{
return
;
}
std
::
vector
<
const
Tensor
*>
inputA
;
std
::
vector
<
const
Tensor
*>
inputA
;
std
::
vector
<
const
Tensor
*>
inputB
;
std
::
vector
<
const
Tensor
*>
inputB
;
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
e698a0a7
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <sstream>
#include <sstream>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "../util/hip_runtime.h"
#endif
#endif
#ifdef USE_ROCBLAS
#ifdef USE_ROCBLAS
...
@@ -887,11 +888,17 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
...
@@ -887,11 +888,17 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
}
//namespace
}
//namespace
static
inline
void
CreateHipBlasLtHandle
(
hipblasLtHandle_t
*
handle
)
{
static
void
CreateHipBlasLtHandle
(
hipblasLtHandle_t
*
handle
)
{
NVTE_CHECK_HIPBLASLT
(
hipblasLtCreate
(
handle
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtCreate
(
handle
));
}
}
using
hipBlasLtHandleManager
=
detail
::
HandleManager
<
hipblasLtHandle_t
,
CreateHipBlasLtHandle
>
;
static
void
DestroyHipBlasLtHandle
(
hipblasLtHandle_t
handle
)
{
if
(
handle
!=
nullptr
)
NVTE_CHECK_HIPBLASLT
(
hipblasLtDestroy
(
handle
));
}
}
using
hipBlasLtHandleManager
=
detail
::
HandleManager
<
hipblasLtHandle_t
,
CreateHipBlasLtHandle
,
DestroyHipBlasLtHandle
>
;
transformer_engine
::
DType
get_transformer_engine_dtype_from_hipblaslt_dtype
(
const
hipDataType
t
)
{
transformer_engine
::
DType
get_transformer_engine_dtype_from_hipblaslt_dtype
(
const
hipDataType
t
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
...
@@ -1240,40 +1247,183 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
...
@@ -1240,40 +1247,183 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescDestroy
(
operationDesc
));
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescDestroy
(
operationDesc
));
}
}
struct
HipBlasLtUserArgsDeleter
{
struct
HipBlasltUserArgs
void
operator
()(
hipblaslt_ext
::
UserArguments
*
ptr
)
const
noexcept
{
{
hipFree
(
ptr
);
HipBlasltUserArgs
()
:
stream_
(
nullptr
),
raw_
(
nullptr
),
event_
(
nullptr
)
{}
HipBlasltUserArgs
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
:
stream_
(
stream
),
raw_
(
nullptr
),
event_
(
nullptr
)
{
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
nullptr
;
if
(
host
)
{
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
else
{
NVTE_CHECK_CUDA
(
hipMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
raw_
=
raw_ptr
;
hipEvent_t
event
=
nullptr
;
if
(
host
)
{
NVTE_CHECK_CUDA
(
hipEventCreateWithFlags
(
&
event
,
hipEventBlockingSync
));
}
else
{
NVTE_CHECK_CUDA
(
hipEventCreateWithFlags
(
&
event
,
hipEventDisableTiming
));
}
event_
=
event
;
}
HipBlasltUserArgs
(
const
HipBlasltUserArgs
&
)
=
delete
;
HipBlasltUserArgs
(
HipBlasltUserArgs
&&
other
)
{
stream_
=
other
.
stream_
;
raw_
=
other
.
raw_
;
event_
=
other
.
event_
;
other
.
stream_
=
nullptr
;
other
.
raw_
=
nullptr
;
other
.
event_
=
nullptr
;
}
HipBlasltUserArgs
&
operator
=
(
const
HipBlasltUserArgs
&
)
=
delete
;
HipBlasltUserArgs
&
operator
=
(
HipBlasltUserArgs
&&
other
)
{
if
(
this
!=
&
other
)
{
free
();
stream_
=
other
.
stream_
;
raw_
=
other
.
raw_
;
event_
=
other
.
event_
;
other
.
stream_
=
nullptr
;
other
.
raw_
=
nullptr
;
other
.
event_
=
nullptr
;
}
return
*
this
;
}
inline
hipStream_t
getStream
()
const
noexcept
{
return
stream_
;
}
inline
hipblaslt_ext
::
UserArguments
*
getArgs
()
const
noexcept
{
return
raw_
;
}
inline
hipEvent_t
getEvent
()
const
noexcept
{
return
event_
;
}
}
inline
void
setStream
(
hipStream_t
stream
)
noexcept
{
stream_
=
stream
;
}
~
HipBlasltUserArgs
()
{
free
();
}
private:
void
free
()
{
if
(
raw_
)
{
if
(
event_
)
{
NVTE_CHECK_CUDA
(
hipEventSynchronize
(
event_
));
NVTE_CHECK_CUDA
(
hipEventDestroy
(
event_
));
event_
=
nullptr
;
}
NVTE_CHECK_CUDA
(
hipFree
(
raw_
));
raw_
=
nullptr
;
}
}
hipStream_t
stream_
;
hipblaslt_ext
::
UserArguments
*
raw_
;
hipEvent_t
event_
;
};
};
using
HipBlasLtUserArgsPtr
=
std
::
unique_ptr
<
hipblaslt_ext
::
UserArguments
,
HipBlasLtUserArgsDeleter
>
;
struct
HipBlasltUserArgsBuffer
{
HipBlasltUserArgsBuffer
()
{}
HipBlasltUserArgsBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
buffer_
[
i
]
=
std
::
move
(
HipBlasltUserArgs
(
stream
,
size
,
host
));
}
}
HipBlasltUserArgsBuffer
(
const
HipBlasltUserArgsBuffer
&
)
=
delete
;
HipBlasltUserArgsBuffer
(
HipBlasltUserArgsBuffer
&&
other
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
buffer_
[
i
]
=
std
::
move
(
other
.
buffer_
[
i
]);
}
index_
=
other
.
index_
;
}
HipBlasltUserArgsBuffer
&
operator
=
(
const
HipBlasltUserArgsBuffer
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
operator
=
(
HipBlasltUserArgsBuffer
&&
other
)
{
if
(
this
!=
&
other
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
buffer_
[
i
]
=
std
::
move
(
other
.
buffer_
[
i
]);
}
index_
=
other
.
index_
;
}
return
*
this
;
}
HipBlasltUserArgs
&
getUserArgs
()
{
HipBlasltUserArgs
&
args
=
buffer_
[
index_
];
inline
HipBlasLtUserArgsPtr
make_hipblaslt_user_args_ptr
(
size_t
size
,
bool
host
)
{
if
(
index_
<
3
)
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
nullptr
;
{
if
(
host
)
{
++
index_
;
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
else
{
NVTE_CHECK_CUDA
(
hipMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
}
return
HipBlasLtUserArgsPtr
(
raw_ptr
);
else
}
{
index_
=
0
;
}
return
args
;
}
private:
int
index_
=
0
;
HipBlasltUserArgs
buffer_
[
4
];
};
// using HipBlasltUserArgsBufferPtr = std::unique_ptr<HipBlasltUserArgsBuffer>;
inline
hipblaslt_ext
::
UserArguments
*
get_hipblaslt_user_args
(
size_t
size
,
bool
host
)
{
struct
HipBlasltUserArgsCache
thread_local
static
std
::
unordered_map
<
size_t
,
HipBlasLtUserArgsPtr
>
host_userargs_cache
;
{
thread_local
static
std
::
unordered_map
<
size_t
,
HipBlasLtUserArgsPtr
>
device_userargs_cache
;
HipBlasltUserArgsCache
()
{}
std
::
unordered_map
<
size_t
,
HipBlasLtUserArgsPtr
>&
user_args_cache
=
host
?
host_userargs_cache
:
device_userargs_cache
;
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
auto
size_it
=
user_args_cache
.
find
(
size
);
HipBlasltUserArgsBuffer
&
operator
=
(
const
HipBlasltUserArgsBuffer
&
)
=
delete
;
if
(
size_it
!=
user_args_cache
.
end
())
{
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
return
size_it
->
second
.
get
();
{
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
auto
size_it
=
buffers
.
find
(
size
);
if
(
size_it
!=
buffers
.
end
())
{
return
size_it
->
second
;
}
}
else
else
{
{
HipBlasLtUserArgsPtr
user_args
=
make_hipblaslt_user_args_ptr
(
size
,
host
);
return
buffers
.
emplace
(
size
,
HipBlasltUserArgsBuffer
{
stream
,
size
,
host
}).
first
->
second
;
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
user_args
.
get
();
user_args_cache
[
size
]
=
std
::
move
(
user_args
);
return
raw_ptr
;
}
}
}
}
private:
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>
host_buffers_
;
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>
device_buffers_
;
};
struct
HipBlasltUserArgsCacheManager
{
static
HipBlasltUserArgsCacheManager
&
instance
()
{
static
thread_local
HipBlasltUserArgsCacheManager
instance_
;
return
instance_
;
}
HipBlasltUserArgsCache
&
getCache
()
{
const
int
device_id
=
cuda
::
current_device
();
NVTE_CHECK
(
0
<=
device_id
&&
device_id
<
caches_
.
size
(),
"invalid CUDA device ID"
);
return
caches_
[
device_id
];
}
private:
HipBlasltUserArgsCacheManager
()
:
caches_
(
cuda
::
num_devices
())
{}
std
::
vector
<
HipBlasltUserArgsCache
>
caches_
;
};
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
void
hipblaslt_groupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
...
@@ -1285,18 +1435,20 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1285,18 +1435,20 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check that compute_stream_offset is valid.
// Check that compute_stream_offset is valid.
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
NVTE_CHECK
(
compute_stream_offset
>=
-
1
&&
compute_stream_offset
<
compute_num_streams
);
hipblaslt_ext
::
UserArguments
*
userArgs
=
get_hipblaslt_user_args
(
m
.
size
(),
true
);
hipblasLtHandle_t
handle
=
hipBlasLtHandleManager
::
Instance
().
GetHandle
();
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
get_hipblaslt_user_args
(
m
.
size
(),
false
);
// hipblaslt_ext::UserArguments* userArgs;
HipBlasltUserArgs
&
device_user_args
=
HipBlasltUserArgsCacheManager
::
instance
().
getCache
().
getBuffer
(
stream
,
m
.
size
(),
false
).
getUserArgs
();
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblaslt_ext
::
UserArguments
*
device_args
=
device_user_args
.
getArgs
();
hipEvent_t
device_event
=
device_user_args
.
getEvent
();
hipStream_t
device_stream
=
device_user_args
.
getStream
();
hipblasLtHandle_t
handle
=
hipBlasLtHandleManager
::
Instance
().
GetHandle
();
HipBlasltUserArgs
&
host_user_args
=
HipBlasltUserArgsCacheManager
::
instance
().
getCache
().
getBuffer
(
stream
,
m
.
size
(),
true
).
getUserArgs
();
hipblaslt_ext
::
UserArguments
*
host_args
=
host_user_args
.
getArgs
();
hipEvent_t
host_event
=
host_user_args
.
getEvent
();
const
hipDataType
A_type
=
get_hipblaslt_dtype
(
inputA
[
0
]
->
data
.
dtype
);
const
hipDataType
A_type
=
get_hipblaslt_dtype
(
inputA
[
0
]
->
data
.
dtype
);
const
hipDataType
B_type
=
get_hipblaslt_dtype
(
inputB
[
0
]
->
data
.
dtype
);
const
hipDataType
B_type
=
get_hipblaslt_dtype
(
inputB
[
0
]
->
data
.
dtype
);
const
hipDataType
D_type
=
get_hipblaslt_dtype
(
outputD
[
0
]
->
data
.
dtype
);
const
hipDataType
D_type
=
get_hipblaslt_dtype
(
outputD
[
0
]
->
data
.
dtype
);
hipblasComputeType_t
computeType
=
HIPBLAS_COMPUTE_32F
;
hipblasComputeType_t
computeType
=
HIPBLAS_COMPUTE_32F
;
float
one
=
1.0
;
float
one
=
1.0
;
...
@@ -1313,16 +1465,14 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1313,16 +1465,14 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
computeType
=
HIPBLAS_COMPUTE_32I
;
computeType
=
HIPBLAS_COMPUTE_32I
;
}
}
hipblaslt_ext
::
GemmPreference
gemmPref
;
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
gemmPref
.
setMaxWorkspaceBytes
(
workspaceSize
);
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
GemmEpilogue
()};
hipblaslt_ext
::
GroupedGemm
groupedgemm
(
handle
,
transa
,
transb
,
A_type
,
B_type
,
D_type
,
D_type
,
computeType
);
std
::
vector
<
hipblaslt_ext
::
GemmEpilogue
>
epilogue
{
hipblaslt_ext
::
GemmEpilogue
()};
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
std
::
vector
<
hipblaslt_ext
::
GemmInputs
>
inputs
(
m
.
size
());
std
::
vector
<
hipblaslt_ext
::
GemmInputs
>
inputs
(
m
.
size
());
for
(
int
i
=
0
;
i
<
m
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
m
.
size
();
i
++
)
{
assert
(
m
[
i
]
!=
0
);
assert
(
n
[
i
]
!=
0
);
assert
(
k
[
i
]
!=
0
);
assert
(
b
[
i
]
!=
0
);
inputs
[
i
].
a
=
inputA
[
i
]
->
data
.
dptr
;
inputs
[
i
].
a
=
inputA
[
i
]
->
data
.
dptr
;
inputs
[
i
].
b
=
inputB
[
i
]
->
data
.
dptr
;
inputs
[
i
].
b
=
inputB
[
i
]
->
data
.
dptr
;
inputs
[
i
].
c
=
outputD
[
i
]
->
data
.
dptr
;
inputs
[
i
].
c
=
outputD
[
i
]
->
data
.
dptr
;
...
@@ -1330,35 +1480,38 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
...
@@ -1330,35 +1480,38 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
inputs
[
i
].
alpha
=
use_int8
?
static_cast
<
void
*>
(
&
int_one
)
:
static_cast
<
void
*>
(
&
one
);
inputs
[
i
].
alpha
=
use_int8
?
static_cast
<
void
*>
(
&
int_one
)
:
static_cast
<
void
*>
(
&
one
);
inputs
[
i
].
beta
=
use_int8
?
static_cast
<
void
*>
(
&
int_beta
)
:
static_cast
<
void
*>
(
&
beta
);
inputs
[
i
].
beta
=
use_int8
?
static_cast
<
void
*>
(
&
int_beta
)
:
static_cast
<
void
*>
(
&
beta
);
}
}
// hipblaslt_ext::GemmEpilogue supports broadcasting
groupedgemm
.
setProblem
(
m
,
n
,
k
,
b
,
epilogue
,
inputs
);
const
int
request_solutions
=
1
;
const
int
request_solutions
=
1
;
std
::
vector
<
hipblasLtMatmulHeuristicResult_t
>
heuristicResult
;
std
::
vector
<
hipblasLtMatmulHeuristicResult_t
>
heuristicResult
;
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
algoGetHeuristic
(
request_solutions
,
gemmPref
,
heuristicResult
));
hipblaslt_ext
::
GemmPreference
gemmPref
;
gemmPref
.
setMaxWorkspaceBytes
(
0
);
hipblaslt_ext
::
GroupedGemm
groupedgemm
(
handle
,
transa
,
transb
,
A_type
,
B_type
,
D_type
,
D_type
,
computeType
);
// hipblaslt_ext::GemmEpilogue supports broadcasting
groupedgemm
.
setProblem
(
m
,
n
,
k
,
b
,
epilogue
,
inputs
);
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
algoGetHeuristic
(
request_solutions
,
gemmPref
,
heuristicResult
));
if
(
heuristicResult
.
empty
())
{
if
(
heuristicResult
.
empty
())
{
std
::
cerr
<<
"No valid solution found!"
<<
std
::
endl
;
std
::
cerr
<<
"No valid solution found!"
<<
std
::
endl
;
return
;
return
;
}
}
// Make sure to initialize every time the algo changes
// Make sure to initialize every time the algo changes
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
initialize
(
heuristicResult
[
0
].
algo
,
workspace
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
initialize
(
heuristicResult
[
0
].
algo
,
nullptr
,
true
,
stream
));
NVTE_CHECK_CUDA
(
hipEventSynchronize
(
host_event
));
// Get the default values from the groupedgemm object
// Get the default values from the groupedgemm object
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
userArgs
);
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
host_args
);
if
(
stream
!=
device_stream
)
{
NVTE_CHECK_CUDA
(
hipStreamWaitEvent
(
stream
,
device_event
,
0
));
}
// Copy them to device memory
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
NVTE_CHECK_CUDA
(
hipMemcpyAsync
(
device_args
,
host_args
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
,
stream
));
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
NVTE_CHECK_CUDA
(
hipEventRecord
(
host_event
,
stream
));
NVTE_CHECK_CUDA
(
hipMemcpy
(
d_userArgs
,
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
device_args
,
stream
));
hipMemcpyHostToDevice
));
device_user_args
.
setStream
(
stream
);
NVTE_CHECK_CUDA
(
hipEventRecord
(
device_event
,
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
d_userArgs
,
stream
));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
}
}
#endif //USE_HIPBLASLT
#endif //USE_HIPBLASLT
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment