OpenDAS / TransformerEngine
Commit 686af9c3, authored Jul 16, 2025 by yuguo
Merge branch 'develop_v2.4' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine
Parents: ee787b22, 9406ff31
Showing 3 changed files with 224 additions and 7 deletions (+224, -7)
transformer_engine/common/gemm/cublaslt_gemm.cu (+60, -0)
transformer_engine/common/gemm/rocm_gemm.cu (+146, -0)
transformer_engine/pytorch/csrc/extensions/gemm.cpp (+18, -7)
transformer_engine/common/gemm/cublaslt_gemm.cu
...
...
@@ -867,6 +867,66 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan
#endif
#ifdef __HIP_PLATFORM_AMD__
void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                       const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
                       bool transa, bool transb, bool grad, NVTETensor *workspace,
                       bool accumulate, bool use_split_accumulator, int math_sm_count,
                       cudaStream_t stream) {
  NVTE_API_CALL(nvte_grouped_gemm);
  using namespace transformer_engine;
  std::vector<const Tensor *> inputA;
  std::vector<const Tensor *> inputB;
  std::vector<Tensor *> outputD;
  std::vector<const Tensor *> biasTensor;
  std::vector<Tensor *> outputGelu;
  std::vector<int64_t> m;
  std::vector<int64_t> n;
  std::vector<int64_t> k;
  std::vector<int64_t> b;
  for (int i = 0; i < num_gemms; i++) {
    inputA.push_back(convertNVTETensorCheck(A[i]));
    inputB.push_back(convertNVTETensorCheck(B[i]));
    outputD.push_back(convertNVTETensorCheck(D[i]));
    biasTensor.push_back(convertNVTETensorCheck(bias[i]));
    outputGelu.push_back(convertNVTETensorCheck(pre_gelu_out[i]));
    b.push_back(1);
    size_t A0 = inputA[i]->flat_first_dim();
    size_t A1 = inputA[i]->flat_last_dim();
    size_t B0 = inputB[i]->flat_first_dim();
    size_t B1 = inputB[i]->flat_last_dim();
    if (transa) {
      m.push_back(A0);
      k.push_back(A1);
    } else {
      m.push_back(A1);
      k.push_back(A0);
    }
    if (transb) {
      n.push_back(B1);
    } else {
      n.push_back(B0);
    }
  }
  Tensor *wspace = convertNVTETensorCheck(workspace[0]);
  if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) {
    NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu.");
  }
  hipblaslt_goupedgemm(inputA, inputB, outputD, m, n, k, b,
                       (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                       (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                       wspace->data.dptr, wspace->data.shape[0],
                       accumulate, use_split_accumulator, math_sm_count, stream);
}

void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                        const NVTETensor *bias, NVTETensor *pre_gelu_out,
...
...
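For orientation, a minimal caller-side sketch of the new ROCm-only entry point follows (it is compiled only under __HIP_PLATFORM_AMD__). The sketch assumes the caller already holds per-GEMM NVTETensor handles, as the PyTorch extension below does; the header path, helper name, and argument values are placeholders and not part of this commit. Note that the bias and pre_gelu_out entries must wrap null data pointers, since the path errors out on bias/GELU fusion, and workspace[0] supplies the hipBLASLt workspace whose first shape dimension is taken as the size in bytes.

// Hypothetical usage sketch; tensor construction is elided and names are placeholders.
#include <vector>
#include <transformer_engine/gemm.h>  // assumed declaration site of nvte_grouped_gemm

void launch_grouped_gemm(std::vector<NVTETensor> &A, std::vector<NVTETensor> &B,
                         std::vector<NVTETensor> &D, std::vector<NVTETensor> &empty_bias,
                         std::vector<NVTETensor> &empty_gelu, NVTETensor workspace,
                         cudaStream_t stream) {
  // empty_bias/empty_gelu entries must have null data pointers; workspace.shape[0]
  // is interpreted as the workspace size in bytes.
  nvte_grouped_gemm(A.data(), B.data(), D.data(), empty_bias.data(), empty_gelu.data(),
                    static_cast<int>(A.size()),
                    /*transa=*/true, /*transb=*/false, /*grad=*/false,
                    &workspace, /*accumulate=*/false, /*use_split_accumulator=*/false,
                    /*math_sm_count=*/0, stream);
}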
transformer_engine/common/gemm/rocm_gemm.cu
...
...
@@ -22,6 +22,7 @@
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#include <hipcub/hipcub.hpp>
#include <hipblaslt/hipblaslt-ext.hpp>
#endif
#include <iostream>
#include <cstdlib>
...
...
@@ -50,6 +51,10 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
      return HIP_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return HIP_R_8F_E5M2;
    case DType::kInt8:
      return HIP_R_8I;
    case DType::kInt32:
      return HIP_R_32I;
    default:
      NVTE_ERROR("Invalid type");
  }
...
...
@@ -1367,6 +1372,147 @@ void hipblaslt_gemm(const Tensor *inputA,
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}

class userArgsManager {
 public:
  userArgsManager() {}

  ~userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto &device_pair : userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs per device
    }
  }

  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments *get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = userArgs_map_.find(device_id);
    if (device_it != userArgs_map_.end()) {
      return device_it->second;
    }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments *userArgs;
    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    userArgs_map_[device_id] = userArgs;
    return userArgs;
  }

 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments *> userArgs_map_;  // Map from device_id to hipblasHandle
  std::mutex mutex_;
};

// Define a static userArgs manager
// static userArgsManager UAManager;

void hipblaslt_goupedgemm(std::vector<const Tensor *> &inputA, std::vector<const Tensor *> &inputB,
                          std::vector<Tensor *> &outputD, std::vector<int64_t> &m,
                          std::vector<int64_t> &n, std::vector<int64_t> &k,
                          std::vector<int64_t> &b, hipblasOperation_t transa,
                          hipblasOperation_t transb, void *workspace, size_t workspaceSize,
                          bool accumulate, bool use_split_accumulator, int math_sm_count,
                          hipStream_t stream, int compute_stream_offset = 0) {
  // Check compute_stream_offset valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
  // int device_id;
  // hipGetDevice(&device_id);
  // hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
  // hipblaslt_ext::UserArguments* userArgs;
  // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
  hipblasLtHandle_t handle = nullptr;
  if (compute_stream_offset != -1) {
    // Init hipblaslt handles (once, globally)
    static std::once_flag init_flag;
    static hipblasLtHandle_t hipblaslt_handles[1];
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
    handle = hipblaslt_handles[compute_stream_offset];
  }
  const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);
  hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;
  int int_one = 1;
  int int_zero = 0;
  int int_beta = int_zero;
  bool use_int8 = false;
  if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) {
    NVTE_CHECK(!accumulate, "Int8 gemm not support accumulate.");
    use_int8 = true;
    computeType = HIPBLAS_COMPUTE_32I;
  }
  hipblaslt_ext::GemmPreference gemmPref;
  gemmPref.setMaxWorkspaceBytes(workspaceSize);
  hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
                                         computeType);
  std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
      hipblaslt_ext::GemmEpilogue()};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
  std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
  for (int i = 0; i < m.size(); i++) {
    inputs[i].a = inputA[i]->data.dptr;
    inputs[i].b = inputB[i]->data.dptr;
    inputs[i].c = outputD[i]->data.dptr;
    inputs[i].d = outputD[i]->data.dptr;
    inputs[i].alpha = use_int8 ? static_cast<void *>(&int_one) : static_cast<void *>(&one);
    inputs[i].beta = use_int8 ? static_cast<void *>(&int_beta) : static_cast<void *>(&beta);
  }
  // hipblaslt_ext::GemmEpilogue supports broadcasting
  groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
  const int request_solutions = 1;
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
  NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
  if (heuristicResult.empty()) {
    std::cerr << "No valid solution found!" << std::endl;
    return;
  }
  // Get the default values from the grouepdgemm object
  // groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
  // Copy them to device memory
  // hipblaslt_ext::UserArguments* d_userArgs;
  // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
  // NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
  //                                userArgs,
  //                                m.size() * sizeof(hipblaslt_ext::UserArguments),
  //                                hipMemcpyHostToDevice, stream));
  // Make sure to initialize everytime the algo changes
  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
  NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
  // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
  // NVTE_CHECK_CUDA(hipFree(userArgs));
}
#endif //USE_HIPBLASLT
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion
...
...
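hipblaslt_goupedgemm receives a raw device workspace pointer plus its size in bytes; the size is advertised to the solver through GemmPreference::setMaxWorkspaceBytes and the same buffer is handed to GroupedGemm::initialize. A minimal sketch of how a caller might provision such a buffer follows; the 32 MiB figure and the helper name are arbitrary placeholders, not something this commit prescribes.

// Hypothetical workspace provisioning for the grouped-GEMM path.
#include <hip/hip_runtime.h>

constexpr size_t kGroupedGemmWorkspaceBytes = 32 * 1024 * 1024;  // placeholder size

void *alloc_grouped_gemm_workspace() {
  void *workspace = nullptr;
  // The buffer only needs to be at least as large as the size later passed to
  // gemmPref.setMaxWorkspaceBytes() and groupedgemm.initialize().
  if (hipMalloc(&workspace, kGroupedGemmWorkspaceBytes) != hipSuccess) {
    return nullptr;
  }
  return workspace;
}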
transformer_engine/pytorch/csrc/extensions/gemm.cpp
...
...
@@ -564,13 +564,24 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     wrappers.emplace_back(std::move(wsp));
   }
-  // For now, we only have multi-stream cublas backend.
-  NVTE_SCOPED_GIL_RELEASE({
-    nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
-                                  te_bias_vector.data(), te_pre_gelu_out_vector.data(),
-                                  te_A_vector.size(), transa, transb, grad,
-                                  te_workspace_vector.data(), accumulate, use_split_accumulator,
-                                  math_sm_count, at::cuda::getCurrentCUDAStream());
-  });
+  const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
+  if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_grouped_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
+                        te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(),
+                        transa, transb, grad, te_workspace_vector.data(), accumulate,
+                        use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream());
+    });
+  } else {
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
+                                    te_bias_vector.data(), te_pre_gelu_out_vector.data(),
+                                    te_A_vector.size(), transa, transb, grad,
+                                    te_workspace_vector.data(), accumulate, use_split_accumulator,
+                                    math_sm_count, at::cuda::getCurrentCUDAStream());
+    });
+  }
   return bias;
 }
...
...
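The dispatch above inspects only the first character of NVTE_USE_HIPBLASLT_GROUPEDGEMM, so any value beginning with '1' selects the hipBLASLt grouped-GEMM backend, and anything else (including an unset variable) falls back to the multi-stream cuBLAS/hipBLAS path. A small sketch of toggling it from the host process before the grouped GEMM runs; the POSIX setenv call and helper name are illustrative, not part of the commit.

// Illustrative only: enable the hipBLASLt grouped-GEMM backend for this process.
#include <cstdlib>  // setenv is available here on POSIX systems

void enable_hipblaslt_grouped_gemm() {
  // Must be set before te_general_grouped_gemm() reads it; only the leading '1' matters.
  setenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM", "1", /*overwrite=*/1);
}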