Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
1261da47
Commit
1261da47
authored
Dec 13, 2025
by
wenjh
Browse files
Complete fix blaslt group gemm dump
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parent
0a90777e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
113 deletions
+98
-113
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+2
-0
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+96
-113
No files found.
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
1261da47
...
...
@@ -991,6 +991,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
bool
use_split_accumulator
,
int
math_sm_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_grouped_gemm
);
if
(
num_gemms
==
0
)
{
return
;
}
using
namespace
transformer_engine
;
std
::
vector
<
const
Tensor
*>
inputA
;
...
...
@@ -1029,6 +1030,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
n
.
push_back
(
B0
);
}
}
Tensor
*
wspace
=
convertNVTETensorCheck
(
workspace
[
0
]);
if
((
biasTensor
[
0
]
->
data
.
dptr
!=
nullptr
)
||
(
outputGelu
[
0
]
->
data
.
dptr
!=
nullptr
))
{
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
1261da47
...
...
@@ -20,6 +20,7 @@
#include <sstream>
#include <unordered_map>
#include <vector>
#include "../util/hip_runtime.h"
#endif
#ifdef USE_ROCBLAS
...
...
@@ -1244,39 +1245,57 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT
(
hipblasLtMatmulDescDestroy
(
operationDesc
));
}
struct
HipBlaslt
Host
UserArgs
struct
HipBlasltUserArgs
{
HipBlaslt
Host
UserArgs
()
:
raw_
(
nullptr
),
event_
(
nullptr
)
{}
HipBlaslt
Host
UserArgs
(
size_t
size
)
:
raw_
(
nullptr
),
event_
(
nullptr
)
HipBlasltUserArgs
()
:
stream_
(
nullptr
),
raw_
(
nullptr
),
event_
(
nullptr
)
{}
HipBlasltUserArgs
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
:
stream_
(
stream
),
raw_
(
nullptr
),
event_
(
nullptr
)
{
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
nullptr
;
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
if
(
host
)
{
NVTE_CHECK_CUDA
(
hipHostMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
else
{
NVTE_CHECK_CUDA
(
hipMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
}
raw_
=
raw_ptr
;
hipEvent_t
event
=
nullptr
;
NVTE_CHECK_CUDA
(
hipEventCreateWithFlags
(
&
event
,
hipEventBlockingSync
));
if
(
host
)
{
NVTE_CHECK_CUDA
(
hipEventCreateWithFlags
(
&
event
,
hipEventBlockingSync
));
}
else
{
NVTE_CHECK_CUDA
(
hipEventCreateWithFlags
(
&
event
,
hipEventDisableTiming
));
}
event_
=
event
;
}
HipBlaslt
Host
UserArgs
(
const
HipBlaslt
Host
UserArgs
&
)
=
delete
;
HipBlaslt
Host
UserArgs
(
HipBlaslt
Host
UserArgs
&&
other
)
HipBlasltUserArgs
(
const
HipBlasltUserArgs
&
)
=
delete
;
HipBlasltUserArgs
(
HipBlasltUserArgs
&&
other
)
{
stream_
=
other
.
stream_
;
raw_
=
other
.
raw_
;
event_
=
other
.
event_
;
other
.
stream_
=
nullptr
;
other
.
raw_
=
nullptr
;
other
.
event_
=
nullptr
;
}
HipBlaslt
Host
UserArgs
&
operator
=
(
const
HipBlaslt
Host
UserArgs
&
)
=
delete
;
HipBlaslt
Host
UserArgs
&
operator
=
(
HipBlaslt
Host
UserArgs
&&
other
)
HipBlasltUserArgs
&
operator
=
(
const
HipBlasltUserArgs
&
)
=
delete
;
HipBlasltUserArgs
&
operator
=
(
HipBlasltUserArgs
&&
other
)
{
if
(
this
!=
&
other
)
{
free
();
stream_
=
other
.
stream_
;
raw_
=
other
.
raw_
;
event_
=
other
.
event_
;
other
.
stream_
=
nullptr
;
other
.
raw_
=
nullptr
;
other
.
event_
=
nullptr
;
}
return
*
this
;
}
inline
hipStream_t
getStream
()
const
noexcept
{
return
stream_
;
}
inline
hipblaslt_ext
::
UserArguments
*
getArgs
()
const
noexcept
{
return
raw_
;
...
...
@@ -1285,12 +1304,16 @@ struct HipBlasltHostUserArgs
{
return
event_
;
}
~
HipBlasltHostUserArgs
()
inline
void
setStream
(
hipStream_t
stream
)
noexcept
{
stream_
=
stream
;
}
~
HipBlasltUserArgs
()
{
free
();
}
private:
inline
void
free
()
void
free
()
{
if
(
raw_
)
{
...
...
@@ -1304,34 +1327,35 @@ private:
raw_
=
nullptr
;
}
}
hipStream_t
stream_
;
hipblaslt_ext
::
UserArguments
*
raw_
;
hipEvent_t
event_
;
};
struct
HipBlaslt
Host
UserArgsBuffer
struct
HipBlasltUserArgsBuffer
{
HipBlaslt
Host
UserArgsBuffer
()
{}
HipBlaslt
Host
UserArgsBuffer
(
size_t
size
)
HipBlasltUserArgsBuffer
()
{}
HipBlasltUserArgsBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
buffer_
[
i
]
=
std
::
move
(
HipBlaslt
Host
UserArgs
(
s
ize
));
buffer_
[
i
]
=
std
::
move
(
HipBlasltUserArgs
(
s
tream
,
size
,
host
));
}
}
HipBlaslt
Host
UserArgsBuffer
(
const
HipBlaslt
Host
UserArgsBuffer
&
)
=
delete
;
HipBlaslt
Host
UserArgsBuffer
(
HipBlaslt
Host
UserArgsBuffer
&&
other
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
HipBlasltUserArgsBuffer
(
const
HipBlasltUserArgsBuffer
&
)
=
delete
;
HipBlasltUserArgsBuffer
(
HipBlasltUserArgsBuffer
&&
other
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
buffer_
[
i
]
=
std
::
move
(
other
.
buffer_
[
i
]);
}
index_
=
other
.
index_
;
}
HipBlaslt
Host
UserArgsBuffer
&
operator
=
(
const
HipBlaslt
Host
UserArgsBuffer
&
)
=
delete
;
HipBlaslt
Host
UserArgsBuffer
&
operator
=
(
HipBlaslt
Host
UserArgsBuffer
&&
other
)
HipBlasltUserArgsBuffer
&
operator
=
(
const
HipBlasltUserArgsBuffer
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
operator
=
(
HipBlasltUserArgsBuffer
&&
other
)
{
if
(
this
!=
&
other
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
buffer_
[
i
]
=
std
::
move
(
other
.
buffer_
[
i
]);
}
...
...
@@ -1339,11 +1363,11 @@ struct HipBlasltHostUserArgsBuffer
}
return
*
this
;
}
HipBlaslt
Host
UserArgs
&
get
Host
UserArgs
()
HipBlasltUserArgs
&
getUserArgs
()
{
HipBlaslt
Host
UserArgs
&
args
=
buffer_
[
index_
];
HipBlasltUserArgs
&
args
=
buffer_
[
index_
];
if
(
index_
<
7
)
if
(
index_
<
3
)
{
++
index_
;
}
...
...
@@ -1356,94 +1380,48 @@ struct HipBlasltHostUserArgsBuffer
}
private:
int
index_
=
0
;
HipBlaslt
Host
UserArgs
buffer_
[
8
];
HipBlasltUserArgs
buffer_
[
4
];
};
using
HipBlas
LtHos
tUserArgsBufferPtr
=
std
::
unique_ptr
<
HipBlaslt
Host
UserArgsBuffer
>
;
//
using HipBlas
l
tUserArgsBufferPtr = std::unique_ptr<HipBlasltUserArgsBuffer>;
HipBlasltHostUserArgsBuffer
*
get
HipBlas
LtHos
tUserArgs
Buffer
(
size_t
size
)
struct
HipBlas
l
tUserArgs
Cache
{
static
thread_local
std
::
unordered_map
<
size_t
,
HipBlasLtHostUserArgsBufferPtr
>
user_args_cache
;
auto
size_it
=
user_args_cache
.
find
(
size
);
if
(
size_it
!=
user_args_cache
.
end
())
{
return
size_it
->
second
.
get
();
}
else
HipBlasltUserArgsCache
()
{}
HipBlasltUserArgsCache
(
const
HipBlasltUserArgsCache
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
operator
=
(
const
HipBlasltUserArgsBuffer
&
)
=
delete
;
HipBlasltUserArgsBuffer
&
getBuffer
(
hipStream_t
stream
,
size_t
size
,
bool
host
)
{
HipBlasLtHostUserArgsBufferPtr
user_args
(
new
HipBlasltHostUserArgsBuffer
(
size
));
HipBlasltHostUserArgsBuffer
*
raw_ptr
=
user_args
.
get
();
user_args_cache
[
size
]
=
std
::
move
(
user_args
);
return
raw_ptr
;
}
}
struct
HipBlasLtDeviceUserArgs
{
HipBlasLtDeviceUserArgs
()
:
stream_
(
nullptr
),
raw_
(
nullptr
)
{}
HipBlasLtDeviceUserArgs
(
hipStream_t
stream
,
size_t
size
)
:
stream_
(
stream
),
raw_
(
nullptr
)
{
hipblaslt_ext
::
UserArguments
*
raw_ptr
=
nullptr
;
NVTE_CHECK_CUDA
(
hipMalloc
(
&
raw_ptr
,
size
*
sizeof
(
hipblaslt_ext
::
UserArguments
)));
raw_
=
raw_ptr
;
}
HipBlasLtDeviceUserArgs
(
const
HipBlasLtDeviceUserArgs
&
)
=
delete
;
HipBlasLtDeviceUserArgs
(
HipBlasLtDeviceUserArgs
&&
other
)
{
stream_
=
other
.
stream_
;
raw_
=
other
.
raw_
;
other
.
stream_
=
nullptr
;
other
.
raw_
=
nullptr
;
}
HipBlasLtDeviceUserArgs
&
operator
=
(
const
HipBlasLtDeviceUserArgs
&
)
=
delete
;
HipBlasLtDeviceUserArgs
&
operator
=
(
HipBlasLtDeviceUserArgs
&&
other
)
{
if
(
this
!=
&
other
)
{
free
();
stream_
=
other
.
stream_
;
raw_
=
other
.
raw_
;
other
.
stream_
=
nullptr
;
other
.
raw_
=
nullptr
;
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>&
buffers
=
host
?
host_buffers_
:
device_buffers_
;
auto
size_it
=
buffers
.
find
(
size
);
if
(
size_it
!=
buffers
.
end
())
{
return
size_it
->
second
;
}
return
*
this
;
}
inline
hipblaslt_ext
::
UserArguments
*
get
()
const
noexcept
{
return
raw_
;
}
~
HipBlasLtDeviceUserArgs
()
{
free
();
}
protected:
inline
void
free
()
{
if
(
raw_
)
else
{
NVTE_CHECK_CUDA
(
hipFreeAsync
(
raw_
,
stream_
));
raw_
=
nullptr
;
return
buffers
.
emplace
(
size
,
HipBlasltUserArgsBuffer
{
stream
,
size
,
host
}).
first
->
second
;
}
}
hipStream_t
stream_
;
hipblaslt_ext
::
UserArguments
*
raw_
;
private:
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>
host_buffers_
;
std
::
unordered_map
<
size_t
,
HipBlasltUserArgsBuffer
>
device_buffers_
;
};
using
HipBlasLtDeviceUserArgsPtr
=
std
::
unique_ptr
<
HipBlasLtDeviceUserArgs
>
;
HipBlasLtDeviceUserArgs
*
getHipBlasLtDeviceUserArgs
(
hipStream_t
stream
,
size_t
size
)
{
static
thread_local
std
::
unordered_map
<
size_t
,
HipBlasLtDeviceUserArgsPtr
>
user_args_cache
;
auto
size_it
=
user_args_cache
.
find
(
size
);
if
(
size_it
!=
user_args_cache
.
end
())
{
return
size_it
->
second
.
get
();
struct
HipBlasltUserArgsCacheManager
{
static
HipBlasltUserArgsCacheManager
&
instance
()
{
static
thread_local
HipBlasltUserArgsCacheManager
instance_
;
return
instance_
;
}
else
{
HipBlasLtDeviceUserArgsPtr
user_args
(
new
HipBlasLtDeviceUserArgs
(
stream
,
size
));
HipBlasLtDeviceUserArgs
*
raw_ptr
=
user_args
.
get
();
user_args_cache
[
size
]
=
std
::
move
(
user_args
);
return
raw_ptr
;
HipBlasltUserArgsCache
&
getCache
()
{
const
int
device_id
=
cuda
::
current_device
();
NVTE_CHECK
(
0
<=
device_id
&&
device_id
<
caches_
.
size
(),
"invalid CUDA device ID"
);
return
caches_
[
device_id
];
}
}
private:
HipBlasltUserArgsCacheManager
()
:
caches_
(
cuda
::
num_devices
())
{}
std
::
vector
<
HipBlasltUserArgsCache
>
caches_
;
};
void
hipblaslt_goupedgemm
(
std
::
vector
<
const
Tensor
*>&
inputA
,
std
::
vector
<
const
Tensor
*>&
inputB
,
std
::
vector
<
Tensor
*>&
outputD
,
std
::
vector
<
int64_t
>&
m
,
...
...
@@ -1456,13 +1434,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
hipblasLtHandle_t
handle
=
hipBlasLtHandleManager
::
Instance
().
GetHandle
();
HipBlasLtDeviceUserArgs
*
device_user_args
=
getHipBlasLtDeviceUserArgs
(
stream
,
m
.
size
());
hipblaslt_ext
::
UserArguments
*
d_userArgs
=
device_user_args
->
get
();
HipBlasltUserArgs
&
device_user_args
=
HipBlasltUserArgsCacheManager
::
instance
().
getCache
().
getBuffer
(
stream
,
m
.
size
(),
false
).
getUserArgs
();
hipblaslt_ext
::
UserArguments
*
device_args
=
device_user_args
.
getArgs
();
hipEvent_t
device_event
=
device_user_args
.
getEvent
();
hipStream_t
device_stream
=
device_user_args
.
getStream
();
HipBlasltHostUserArgsBuffer
*
host_user_args_buffer
=
getHipBlasLtHostUserArgsBuffer
(
m
.
size
());
HipBlasltHostUserArgs
&
host_user_args
=
host_user_args_buffer
->
getHostUserArgs
();
hipblaslt_ext
::
UserArguments
*
userArgs
=
host_user_args
.
getArgs
();
hipEvent_t
event
=
host_user_args
.
getEvent
();
HipBlasltUserArgs
&
host_user_args
=
HipBlasltUserArgsCacheManager
::
instance
().
getCache
().
getBuffer
(
stream
,
m
.
size
(),
true
).
getUserArgs
();
hipblaslt_ext
::
UserArguments
*
host_args
=
host_user_args
.
getArgs
();
hipEvent_t
host_event
=
host_user_args
.
getEvent
();
const
hipDataType
A_type
=
get_hipblaslt_dtype
(
inputA
[
0
]
->
data
.
dtype
);
const
hipDataType
B_type
=
get_hipblaslt_dtype
(
inputB
[
0
]
->
data
.
dtype
);
...
...
@@ -1502,8 +1481,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
const
int
request_solutions
=
1
;
std
::
vector
<
hipblasLtMatmulHeuristicResult_t
>
heuristicResult
;
NVTE_CHECK_CUDA
(
hipEventSynchronize
(
event
));
hipblaslt_ext
::
GemmPreference
gemmPref
;
gemmPref
.
setMaxWorkspaceBytes
(
0
);
...
...
@@ -1519,13 +1496,19 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Make sure to initialize every time the algo changes
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
initialize
(
heuristicResult
[
0
].
algo
,
nullptr
,
true
,
stream
));
NVTE_CHECK_CUDA
(
hipEventSynchronize
(
host_event
));
// Get the default values from the groupedgemm object
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
userArgs
);
groupedgemm
.
getDefaultValueForDeviceUserArguments
(
host_args
);
if
(
stream
!=
device_stream
)
{
NVTE_CHECK_CUDA
(
hipStreamWaitEvent
(
stream
,
device_event
,
0
));
}
// Copy them to device memory
NVTE_CHECK_CUDA
(
hipMemcpyAsync
(
d_userArgs
,
userArgs
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
,
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
d_userArgs
,
stream
));
NVTE_CHECK_CUDA
(
hipEventRecord
(
event
,
stream
));
NVTE_CHECK_CUDA
(
hipMemcpyAsync
(
device_args
,
host_args
,
m
.
size
()
*
sizeof
(
hipblaslt_ext
::
UserArguments
),
hipMemcpyHostToDevice
,
stream
));
NVTE_CHECK_CUDA
(
hipEventRecord
(
host_event
,
stream
));
NVTE_CHECK_HIPBLASLT
(
groupedgemm
.
run
(
device_args
,
stream
));
device_user_args
.
setStream
(
stream
);
NVTE_CHECK_CUDA
(
hipEventRecord
(
device_event
,
stream
));
}
#endif //USE_HIPBLASLT
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment