OpenDAS / TransformerEngine · Commits

Commit c37084b9, authored May 07, 2025 by yuguo
Parent: c686efc1

[DCU] support NVTE_MOE_BATCHCOUNT
Showing 2 changed files with 37 additions and 17 deletions (+37 / -17):

    transformer_engine/common/gemm/cublaslt_gemm.cu        +18  -4
    transformer_engine/pytorch/module/batched_linear.py    +19  -13
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -731,6 +731,21 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan
 #endif
 #ifdef __HIP_PLATFORM_AMD__
+static inline int getIntEnv(const char *name, int defval, int minval) {
+  int val = defval;
+  const char *env = std::getenv(name);
+  if (env != nullptr && env[0] != '\0') {
+    val = atoi(env);
+    if (val < minval) {
+      val = minval;
+    }
+  }
+  return val;
+}
 void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                         const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                         const int num_gemms, bool transa, bool transb, bool grad,
...
@@ -739,18 +754,17 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
                                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
   using namespace transformer_engine;
-  assert(num_gemms % num_batchgemm_streams == 0);
-  static int batch_count = num_gemms / num_batchgemm_streams;
+  int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
   // Inits streams and events (once, globally)
   std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
-  int num_stream_used = num_batchgemm_streams;
+  int num_stream_used = std::min(num_batchgemm_streams, num_gemms);
   // wait for current stream to finish
   NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
   for (int s = 0; s < num_stream_used; s++) {
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
   }
-  for (int i = 0; i < num_stream_used; i++) {
+  for (int i = 0; i < num_gemms; i++) {
     nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                           workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator,
                           math_sm_count, batch_count,
                           compute_streams_batchgemm[i % num_batchgemm_streams]);
...
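The batch count is no longer derived from num_gemms / num_batchgemm_streams; it is read from NVTE_MOE_BATCHCOUNT with a default of 2 and a floor of 1. For reference, a minimal Python sketch of the same clamped environment read (the name get_int_env is ours, not part of the commit):

    import os

    def get_int_env(name: str, defval: int, minval: int) -> int:
        # Mirror of the C++ getIntEnv above: unset/empty -> default,
        # and any parsed value is clamped up to minval.
        env = os.getenv(name)
        if not env:
            return defval
        try:
            val = int(env)
        except ValueError:
            # C's atoi would return 0 here and then clamp to minval;
            # falling back to the default is a deliberate simplification.
            return defval
        return max(val, minval)

    print(get_int_env("NVTE_MOE_BATCHCOUNT", 2, 1))  # 2 when the variable is unset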
transformer_engine/pytorch/module/batched_linear.py

@@ -4,7 +4,7 @@
 """BatchedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
+import os
 import torch
 import transformer_engine_torch as tex
...
@@ -79,6 +79,8 @@ class _BatchedLinear(torch.autograd.Function):
         *weights_and_biases,
     ) -> torch.Tensor:
+        batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
         # pylint: disable=missing-function-docstring
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
...
@@ -158,8 +160,9 @@ class _BatchedLinear(torch.autograd.Function):
         biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
+        assert weights_fp8[0].size(0) % batch_num == 0, "weights_fp8[0].size(0) should be a multiple of batch_num."
         out = torch.empty(
-            [sum(m_splits), weights_fp8[0].size(0)],
+            [sum(m_splits), weights_fp8[0].size(0) // batch_num],
             dtype=activation_dtype,
             device=device,
         )
...
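The new output width follows because the packed weight's leading dimension now stacks batch_num weight matrices, so only size(0) // batch_num columns belong to each token's output. A small sketch with made-up shapes (all values here are ours, not from the commit):

    import torch

    batch_num = 2                      # NVTE_MOE_BATCHCOUNT
    m_splits = [128, 256]              # rows handled by each grouped GEMM
    weight0 = torch.empty(512, 1024)   # leading dim packs batch_num slices of 256

    # Per-token output width is the per-batch slice, not the packed width:
    out = torch.empty([sum(m_splits), weight0.size(0) // batch_num])
    print(out.shape)  # torch.Size([384, 256])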
@@ -448,7 +451,9 @@ class BatchedLinear(TransformerEngineBaseModule):
         super().__init__()
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
-        self.num_gemms = num_gemms
+        self.batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
+        assert num_gemms % self.batch_num == 0, "Number of GEMMs should be a multiple of batch_num."
+        self.num_gemms = num_gemms // self.batch_num
         self.in_features = in_features
         self.out_features = out_features
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
...
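The constructor thus reinterprets num_gemms as groups of batch_num GEMMs, each launched as one batched GEMM. A sketch of the folding with example numbers (ours):

    import os

    os.environ["NVTE_MOE_BATCHCOUNT"] = "2"   # example setting, ours

    num_gemms = 8
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
    assert num_gemms % batch_num == 0, "Number of GEMMs should be a multiple of batch_num."

    # Each remaining GEMM slot drives one batched GEMM over batch_num matrices.
    num_gemms //= batch_num
    print(num_gemms)  # 4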
@@ -464,7 +469,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
-        self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
+        self._offsets = {"input": 0, "weight": self.num_gemms, "output": 2 * self.num_gemms, "grad_output": 0}
         if tp_group is None:
             self.tp_size = tp_size
...
@@ -483,17 +488,17 @@ class BatchedLinear(TransformerEngineBaseModule):
         if self.parallel_mode == "column":
             self.out_features = divide(self.out_features, self.tp_size)
         elif self.parallel_mode == "row":
-            self.in_features = divide(self.in_features, self.tp_size)
+            self.in_features = divide(self.in_features * self.batch_num, self.tp_size)
         self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
+        # In batchgemm, we use batch=batch_num to launch blas batchgemm
         for i in range(self.num_gemms):
             # Construct weight parameter
             self.register_parameter(
                 f"weight{i}",
                 torch.nn.Parameter(
                     torch.empty(
-                        self.out_features,
+                        self.out_features * self.batch_num,
                         self.in_features,
                         device=device,
                         dtype=params_dtype,
...
@@ -548,15 +553,15 @@ class BatchedLinear(TransformerEngineBaseModule):
         # Set parallelism attributes for linear biases
         if self.use_bias:
-            for i in range(self.num_gemms):
+            for bias in self.bias_names:
                 if self.parallel_mode == "row":
                     setattr(
-                        getattr(self, f"bias{i}"),
+                        getattr(self, bias),
                         "sequence_parallel",
                         self.sequence_parallel,
                     )
                 elif self.parallel_mode == "column":
-                    set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
+                    set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)

     @no_torch_dynamo()
     def forward(
...
@@ -591,7 +596,8 @@ class BatchedLinear(TransformerEngineBaseModule):
         assert not isinstance(
             inp, Float8Tensor
         ), "BatchedLinear doesn't support input tensor in FP8."
-        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
+        m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0 : int(self.num_gemms)]]
+        assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs."
         skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
         if skip_fp8_weight_update is not None:
...
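Since each batched GEMM now covers batch_num logical GEMMs, the per-call row splits are scaled accordingly. A sketch with example values (all numbers are ours):

    batch_num = 2
    num_gemms = 3                     # self.num_gemms after the constructor's folding
    m_splits = [64, 128, 32]          # caller-provided per-GEMM row splits

    # Only the first num_gemms entries are used, each scaled by batch_num:
    m_splits_batch_gemm = [x * batch_num for x in m_splits[:num_gemms]]
    print(m_splits_batch_gemm)        # [128, 256, 64]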
@@ -641,7 +647,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         args = [None]
         args += (
             inp,
-            m_splits,
+            m_splits_batch_gemm,
             self.apply_bias and not self.gemm_bias_unfused_add,
             is_first_microbatch,
             self.fp8,
...
@@ -668,7 +674,7 @@ class BatchedLinear(TransformerEngineBaseModule):
                 [
                     o + cast_if_needed(b, self.activation_dtype)
                     for o, b in zip(
-                        torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
+                        torch.split(out.view(-1, self.out_features), m_splits_batch_gemm), bias_tensors
                    )
                 ]
             ).view(out_shape)
...
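The bias add at the end must therefore split the flattened output by the scaled splits. A self-contained sketch of that regrouping (all shapes and values are ours, not from the commit):

    import torch

    batch_num, out_features = 2, 4
    m_splits = [3, 5]
    m_splits_batch_gemm = [m * batch_num for m in m_splits]        # [6, 10]

    out = torch.zeros(sum(m_splits_batch_gemm) * out_features)     # flat GEMM output
    bias_tensors = [torch.zeros(out_features), torch.ones(out_features)]

    # Split row-wise by the scaled splits, add each group's bias, re-concatenate:
    chunks = torch.split(out.view(-1, out_features), m_splits_batch_gemm)
    result = torch.cat([o + b for o, b in zip(chunks, bias_tensors)])
    print(result.shape)  # torch.Size([16, 4])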