Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
47077129
Commit
47077129
authored
Oct 16, 2025
by
yuguo
Browse files
[DCU] remove redundant gemm
parent
aa62d24c
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
6 additions
and
494 deletions
+6
-494
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+3
-81
transformer_engine/common/gemm/rocm_gemm.cu
transformer_engine/common/gemm/rocm_gemm.cu
+0
-405
transformer_engine/common/include/transformer_engine/gemm.h
transformer_engine/common/include/transformer_engine/gemm.h
+1
-6
transformer_engine/pytorch/csrc/extensions/gemm.cpp
transformer_engine/pytorch/csrc/extensions/gemm.cpp
+2
-2
No files found.
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
47077129
...
@@ -1166,82 +1166,13 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...
@@ -1166,82 +1166,13 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
stream
);
stream
);
}
}
// add for batchgemm
/// @brief Batched GEMM entry point (v2) for MOE-style grouped matmuls on ROCm.
///
/// Computes D = op(A) * op(B) as `batch_count` equal-sized sub-GEMMs, where the
/// batched operand's leading dimension is `shape[0] / batch_count`. Dispatches
/// to hipblas_batchgemm. Bias and pre-GELU outputs are not supported and must
/// be null. The TT layout (transa && transb) is rejected.
///
/// @param A, B                  Input tensors (checked for validity).
/// @param D                     Output tensor.
/// @param bias                  Must have a null data pointer (unsupported).
/// @param pre_gelu_out          Must have a null data pointer (unsupported).
/// @param transa, transb        Operand transpose flags; transa && transb is an error.
/// @param grad                  Gradient-mode flag forwarded to the backend.
/// @param workspace             Workspace tensor; shape[0] is its byte size.
/// @param accumulate            Accumulate into D instead of overwriting.
/// @param use_split_accumulator FP8 split-accumulator flag forwarded to the backend.
/// @param math_sm_count         Number of SMs/CUs to use (0 = all).
/// @param batch_count           Number of sub-GEMMs in the batch.
/// @param stream                Stream on which the GEMM is enqueued.
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D,
                              const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
                              bool transb, bool grad, NVTETensor workspace, bool accumulate,
                              bool use_split_accumulator, int math_sm_count, int batch_count,
                              cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_batchgemm_v2);
  using namespace transformer_engine;
  const Tensor *inputA = convertNVTETensorCheck(A);
  const Tensor *inputB = convertNVTETensorCheck(B);
  Tensor *outputD = convertNVTETensor(D);
  const Tensor *biasTensor = convertNVTETensor(bias);
  Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
  Tensor *wspace = convertNVTETensor(workspace);
  if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
    NVTE_ERROR("MOE batchgemm does not support bias or gelu.");
  }
  // Reject TT up front so the dimension math below never runs on an
  // unsupported layout (the original left m/n/k uninitialized in this case).
  if (transa && transb) {
    NVTE_ERROR("TT layout not allowed.");
  }
  // The ternaries below already select the correct axis for every accepted
  // layout (NT, TN, NN), so a single computation replaces the three formerly
  // identical per-layout branches. The batched operand contributes
  // shape[0] / batch_count rows per sub-GEMM.
  const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
  const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
  const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
  // Leading dimensions in column-major convention:
  //   TN: lda = k, ldb = k;  NN: lda = m, ldb = k;  NT: lda = m, ldb = n.
  const int lda = transa ? k : m;
  const int ldb = transb ? n : k;
  const int ldd = m;
  hipblas_batchgemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd,
                    transa ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                    transb ? HIPBLAS_OP_T : HIPBLAS_OP_N, grad, wspace->data.dptr,
                    wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count,
                    0, 0, false, nullptr, batch_count, stream);
}
// add for batchgemm
// add for batchgemm
void
nvte_cublas_batchgemm_
v3
(
const
NVTETensor
A
,
const
NVTETensor
B
,
const
NVTETensor
A_scales
,
const
NVTETensor
B_scales
,
NVTETensor
D
,
const
NVTETensor
bias
,
void
nvte_cublas_batchgemm_
tensorwise_int8
(
const
NVTETensor
A
,
const
NVTETensor
B
,
const
NVTETensor
A_scales
,
const
NVTETensor
B_scales
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
)
{
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_cublas_batchgemm_
v3
);
NVTE_API_CALL
(
nvte_cublas_batchgemm_
tensorwise_int8
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputA
=
convertNVTETensorCheck
(
A
);
const
Tensor
*
inputB
=
convertNVTETensorCheck
(
B
);
const
Tensor
*
inputB
=
convertNVTETensorCheck
(
B
);
...
@@ -1297,16 +1228,7 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
...
@@ -1297,16 +1228,7 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
handle
=
hipblaslt_handles
[
0
];
handle
=
hipblaslt_handles
[
0
];
hipblaslt_batchgemm_tensorwise_int8
(
inputA
,
inputB
,
inputA_scales
,
inputB_scales
,
outputD
,
biasTensor
,
outputGelu
,
NVTE_ERROR
(
"Remove nvte_cublas_batchgemm_tensorwise_int8 for now."
);
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
(
transa
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
(
transb
)
?
HIPBLAS_OP_T
:
HIPBLAS_OP_N
,
grad
,
wspace
->
data
.
dptr
,
wspace
->
data
.
shape
[
0
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
0
,
0
,
false
,
nullptr
,
batch_count
,
stream
,
handle
);
}
}
#endif
#endif
...
...
transformer_engine/common/gemm/rocm_gemm.cu
View file @
47077129
This diff is collapsed.
Click to expand it.
transformer_engine/common/include/transformer_engine/gemm.h
View file @
47077129
...
@@ -152,12 +152,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
...
@@ -152,12 +152,7 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
void
nvte_cublas_batchgemm_v2
(
const
NVTETensor
A
,
const
NVTETensor
B
,
NVTETensor
D
,
const
NVTETensor
bias
,
void
nvte_cublas_batchgemm_tensorwise_int8
(
const
NVTETensor
A
,
const
NVTETensor
B
,
const
NVTETensor
A_scales
,
const
NVTETensor
B_scales
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
void
nvte_cublas_batchgemm_v3
(
const
NVTETensor
A
,
const
NVTETensor
B
,
const
NVTETensor
A_scales
,
const
NVTETensor
B_scales
,
NVTETensor
D
,
const
NVTETensor
bias
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
pre_gelu_out
,
bool
transa
,
bool
transb
,
bool
grad
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
NVTETensor
workspace
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
int
math_sm_count
,
int
batch_count
,
cudaStream_t
stream
);
...
...
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @
47077129
...
@@ -588,7 +588,7 @@ std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle
...
@@ -588,7 +588,7 @@ std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle
}
else
{
}
else
{
// Launch GEMM
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE
({
NVTE_SCOPED_GIL_RELEASE
({
nvte_cublas_batchgemm
_v2
(
A_tensor
.
data
(),
B_tensor
.
data
(),
D_tensor
.
data
(),
bias_tensor
.
data
(),
nvte_cublas_batchgemm
(
A_tensor
.
data
(),
B_tensor
.
data
(),
D_tensor
.
data
(),
bias_tensor
.
data
(),
te_pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
te_workspace
.
data
(),
te_pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
te_workspace
.
data
(),
accumulate
,
use_split_accumulator
,
num_math_sms
,
batch_count
,
main_stream
);
accumulate
,
use_split_accumulator
,
num_math_sms
,
batch_count
,
main_stream
);
});
});
...
@@ -724,7 +724,7 @@ std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py:
...
@@ -724,7 +724,7 @@ std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py:
}
else
{
}
else
{
// Launch GEMM
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE
({
NVTE_SCOPED_GIL_RELEASE
({
nvte_cublas_batchgemm_
v3
(
A_tensor
.
data
(),
B_tensor
.
data
(),
A_scales_tensor
.
data
(),
B_scales_tensor
.
data
(),
D_tensor
.
data
(),
bias_tensor
.
data
(),
nvte_cublas_batchgemm_
tensorwise_int8
(
A_tensor
.
data
(),
B_tensor
.
data
(),
A_scales_tensor
.
data
(),
B_scales_tensor
.
data
(),
D_tensor
.
data
(),
bias_tensor
.
data
(),
te_pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
te_workspace
.
data
(),
te_pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
te_workspace
.
data
(),
accumulate
,
use_split_accumulator
,
num_math_sms
,
batch_count
,
main_stream
);
accumulate
,
use_split_accumulator
,
num_math_sms
,
batch_count
,
main_stream
);
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment