OpenDAS / TransformerEngine

Commit 067c2b3d, authored Jun 19, 2025 by yuguo

Merge branch 'develop_v2.4' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine

Parents: 4cc47ca6, d6c32078

Showing 4 changed files with 51 additions and 12 deletions.
transformer_engine/common/gemm/cublaslt_gemm.cu                    +35  -0
transformer_engine/pytorch/cpp_extensions/gemm.py                  +12  -7
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +3  -4
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +1  -1
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -596,8 +596,43 @@ static cudaEvent_t cublas_event_batchgemm[num_batchgemm_streams];
 // Warning: only call once per device!
 static void init_streams_and_events_batchgemm() {
+  int comm_cu_nums = getIntEnv("TORCH_COMM_CU_NUMS", 8, 4);
+  unsigned int cuMask[4];
+  unsigned int cuMaskSize = 4;
+  if (comm_cu_nums == 4) {
+    cuMask[0] = 0xfffffff0;
+    cuMask[1] = 0xffffffff;
+    cuMask[2] = 0xffffffff;
+    cuMask[3] = 0xffffffff;
+  } else if (comm_cu_nums == 8) {
+    cuMask[0] = 0xffffff00;
+    cuMask[1] = 0xffffffff;
+    cuMask[2] = 0xffffffff;
+    cuMask[3] = 0xffffffff;
+  } else if (comm_cu_nums == 16) {
+    cuMask[0] = 0xffff0000;
+    cuMask[1] = 0xffffffff;
+    cuMask[2] = 0xffffffff;
+    cuMask[3] = 0xffffffff;
+  } else if (comm_cu_nums == 32) {
+    cuMask[0] = 0x00000000;
+    cuMask[1] = 0xffffffff;
+    cuMask[2] = 0xffffffff;
+    cuMask[3] = 0xffffffff;
+  } else {
+    NVTE_CHECK(false, "comm_cu_nums must be 4,8,16,32");
+  }
+  const char *TORCH_COMM_CU_NUMS = std::getenv("TORCH_COMM_CU_NUMS");
   for (int i = 0; i < num_batchgemm_streams; i++) {
+#ifdef __HIP_PLATFORM_AMD__
+    if (TORCH_COMM_CU_NUMS != nullptr && TORCH_COMM_CU_NUMS[0] != '\0') {
+      NVTE_CHECK_CUDA(hipExtStreamCreateWithCUMask(&compute_streams_batchgemm[i], cuMaskSize, cuMask));
+    } else {
+      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
+    }
+#else
     NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams_batchgemm[i], cudaStreamNonBlocking, -1));
+#endif
     NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event_batchgemm[i]));
   }
 }
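For orientation, here is a minimal Python sketch (not part of the commit) of how the hard-coded cuMask words above relate to comm_cu_nums: each bit of the four 32-bit words corresponds to one compute unit, and clearing the low comm_cu_nums bits of word 0 keeps those CUs out of the batch-GEMM streams, presumably reserving them for communication (inferred from the TORCH_COMM_CU_NUMS name). The helper name make_cu_mask is made up for illustration.

```python
# Illustration only: reproduces the cuMask table hard-coded in the commit.
# Bit i of word (i // 32) corresponds to CU i; a set bit lets the stream run
# on that CU, a cleared bit withholds it from the stream.

def make_cu_mask(comm_cu_nums: int, num_words: int = 4) -> list[int]:
    assert comm_cu_nums in (4, 8, 16, 32), "comm_cu_nums must be 4,8,16,32"
    mask = [0xFFFFFFFF] * num_words                      # all CUs enabled
    mask[0] = (0xFFFFFFFF << comm_cu_nums) & 0xFFFFFFFF  # clear low bits of word 0
    return mask

assert make_cu_mask(4)[0] == 0xFFFFFFF0
assert make_cu_mask(8)[0] == 0xFFFFFF00
assert make_cu_mask(16)[0] == 0xFFFF0000
assert make_cu_mask(32)[0] == 0x00000000
```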
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -209,7 +209,6 @@ def general_grouped_gemm(
     if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
         assert len(set(m_splits)) == 1, "Int8 simulation groupgemm only supports token padding identical to batchgemm for now."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
-        assert bias is None, "Bias not supported with int8 simulation groupgemm."
         if layout == "TN":
             qx_data = [
@@ -221,11 +220,14 @@ def general_grouped_gemm(
             ref_scales_x = [b._rowwise_scale_inv for b in B]
             ref_scales_w = [a._rowwise_scale_inv for a in A]
-            y, _ = w8a8_block_int8_matmul_batched(
-                qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
+            num_gemms = len(A)
+            seq_len = sum(m_splits) // num_gemms
+            out[0], _ = w8a8_block_int8_matmul_batched(
+                qx_data, qw_data, ref_scales_x, ref_scales_w,
+                out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                 output_dtype=out_dtype
             )
-            return y, None, None
+            return out, None, None
         elif layout == "NN":
             qdout_data = [
@@ -236,12 +238,15 @@ def general_grouped_gemm(
             ]
             ref_scales_dout = [b._rowwise_scale_inv for b in B]
             ref_scales_w = [a._columnwise_scale_inv for a in A]
-            y, _ = w8a8_block_int8_matmul_batched(
-                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
+            num_gemms = len(A)
+            seq_len = sum(m_splits) // num_gemms
+            out[0], _ = w8a8_block_int8_matmul_batched(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w,
+                out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
                 output_dtype=out_dtype
             )
-            return y, None, None
+            return out, None, None
         elif layout == "NT":
             qdout_data = [
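To see what the new calling convention in gemm.py implies for shapes, a small self-contained sketch, assuming the equal m_splits that the assert above enforces; all sizes below are stand-ins, not values from the commit:

```python
import torch

# Stand-in sizes for a grouped GEMM with equal per-expert splits.
num_gemms = 4                       # len(A): one GEMM per expert
m_splits = [512, 512, 512, 512]     # must all be equal (asserted above)
N = 1024                            # output feature dimension

seq_len = sum(m_splits) // num_gemms            # tokens per GEMM: 512

# The caller preallocates the flat [sum(m_splits), N] output buffer...
out0 = torch.empty(sum(m_splits), N, dtype=torch.float16)

# ...and the batched kernel now writes into a [num_gemms, seq_len, N] view
# of it instead of allocating and returning a fresh tensor.
batched = out0.view(num_gemms, seq_len, out0.size(-1))
assert batched.data_ptr() == out0.data_ptr()    # same storage, zero-copy
```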
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py

@@ -446,7 +446,7 @@ def _w8a8_block_int8_matmul_batched(
 def w8a8_block_int8_matmul_batched(
     A_list, B_list, As_list, Bs_list,
-    block_size, output_dtype=torch.float16, best_config=None
+    C, block_size, output_dtype=torch.float16, best_config=None
 ):
     A = torch.stack(A_list).contiguous()  # [B, M, K]
@@ -460,8 +460,7 @@ def w8a8_block_int8_matmul_batched(
     batch, N, K = B.shape
     block_n, block_k = block_size
-    C_shape = A.shape[:-1] + (N,)
-    C = A.new_empty(C_shape, dtype=output_dtype)
+    assert C.size(-1) == N
     config = {
         "BLOCK_SIZE_M": 64,
@@ -507,7 +506,7 @@ def w8a8_block_int8_matmul_batched(
         **config,
     )
-    return C
+    return C.view(-1, C.size(-1))


 def apply_w8a8_block_int8_linear_batched_helper(m: int,
                                                 n: int,
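Seen from the wrapper's side, the change in blockwise_int8_gemm_nt.py inverts the old allocate-and-return contract: C arrives from the caller, is shape-checked instead of created with new_empty, and a flat 2-D view of it is returned. A reduced stand-in of that contract, with torch.bmm in place of the Triton launch:

```python
import torch

def batched_matmul_into_sketch(A, B, C):
    """Stand-in for the new contract: write into caller-provided C."""
    batch, M, K = A.shape
    _, N, _ = B.shape
    assert C.size(-1) == N                    # replaces the old new_empty allocation
    torch.bmm(A, B.transpose(1, 2), out=C)    # real code launches the Triton kernel
    return C.view(-1, C.size(-1))             # flat [batch * M, N] view, as in the diff

A = torch.randn(2, 8, 16)
B = torch.randn(2, 4, 16)                     # NT layout: B is [batch, N, K]
C = torch.empty(2, 8, 4)
y = batched_matmul_into_sketch(A, B, C)
assert y.shape == (16, 4) and y.data_ptr() == C.data_ptr()
```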
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py

@@ -532,7 +532,7 @@ def w8a8_block_int8_matmul_wgrad_batched(
         **config,
     )
-    return C
+    return [C[i] for i in range(C.size(0))]


 def apply_w8a8_block_int8_linear_batched_helper(m: int,
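The wgrad wrapper's return type moves the other way: instead of the stacked [batch, ...] tensor, callers now receive a Python list with one per-GEMM gradient. A minimal illustration of what the new `[C[i] for i in range(C.size(0))]` expression yields:

```python
import torch

C = torch.arange(24.0).view(3, 2, 4)          # stacked wgrad results, one per GEMM
grads = [C[i] for i in range(C.size(0))]      # the commit's new return value

assert len(grads) == 3
assert grads[0].shape == (2, 4)
assert grads[0].data_ptr() == C.data_ptr()    # slices are views, not copies
```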