OpenDAS / TransformerEngine
Commit 460b006c, authored May 20, 2025 by yuguo
[DCU] support delay_wgrad_compute in batchgemm
Parent: 196a213f
Showing 2 changed files with 119 additions and 44 deletions:
  tests/pytorch/test_batched_linear.py (+29, -4)
  transformer_engine/pytorch/module/batched_linear.py (+90, -40)
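In short, the commit lets callers defer BatchedLinear's weight-gradient GEMMs and trigger them explicitly after the backward pass. A minimal usage sketch, modeled on the updated test below; the import path follows the changed module file, while the feature sizes and the exact constructor/forward signatures are illustrative assumptions rather than values taken from this commit:

    import torch
    from transformer_engine.pytorch.module.batched_linear import BatchedLinear

    # Assumed sizes; the real constructor accepts more options (see the diff below).
    layer = BatchedLinear(num_gemms=2, in_features=128, out_features=128,
                          delay_wgrad_compute=True, device="cuda")

    inp = torch.randn(4, 128, device="cuda", requires_grad=True)
    out = layer(inp)
    out.sum().backward()     # dgrad runs here; the batched wgrad GEMM is only queued
    layer.backward_dw()      # execute the delayed weight-gradient computation
    torch.cuda.synchronize()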
tests/pytorch/test_batched_linear.py
@@ -171,7 +171,7 @@ def reset_global_fp8_state():
    FP8GlobalStateManager.reset()


def _test_batched_linear_accuracy(
-    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
):
    reset_rng_states()
    if fp8:
@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
    )
    loss = out.sum()
    loss.backward()
    if delay_wgrad_compute:
        if isinstance(block, BatchedLinear):
            block.backward_dw()
        else:
            for i in range(num_gemms):
                block[i].backward_dw()
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if isinstance(block, BatchedLinear):
                if getattr(p, "main_grad", None) is not None:
                    for j in range(batch_num):
                        outputs.append(
                            p.main_grad[p.main_grad.shape[0] // batch_num * j : p.main_grad.shape[0] // batch_num * (j + 1)]
                        )
                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
                else:
                    for j in range(batch_num):
                        outputs.append(
                            p.grad[p.grad.shape[0] // batch_num * j : p.grad.shape[0] // batch_num * (j + 1)]
                        )
            else:
                if getattr(p, "main_grad", None) is not None:
                    outputs.append(p.main_grad)
                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
                else:
                    outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_batched_linear_accuracy(
    dtype,
    num_gemms,
@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    delay_wgrad_compute,
    parallel_mode=None,
):
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
        parallel_mode=parallel_mode,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        delay_wgrad_compute=delay_wgrad_compute,
    ).eval()
    sequential_linear = torch.nn.ModuleList(
        [
@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
                sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                    weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)
                ].clone()
    outputs_ref = _test_batched_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
    )
    outputs = _test_batched_linear_accuracy(
-        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
    )

    # Shoule be bit-wise match
@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
        torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


if __name__ == "__main__":
-    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
+    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)
transformer_engine/pytorch/module/batched_linear.py
@@ -6,7 +6,7 @@
import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
import functools
import torch
import transformer_engine_torch as tex
@@ -18,6 +18,7 @@ from .base import (
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
    divide,
@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
        ctx.tp_size = tp_size
        ctx.requires_dgrad = inp.requires_grad
        ctx.reduce_and_update_bwd_fp8_tensors = False
        ctx.wgrad_store = wgrad_store

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -246,53 +249,69 @@ class _BatchLinear(torch.autograd.Function):
                torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
                for w in weights
            ]

-            # WGRAD
-            _, grad_biases, _ = batchgemm(
-                inputmats,
-                grad_output_mats,
-                wgrad_list,
-                ctx.activation_dtype,
-                get_multi_stream_cublas_batchgemm_workspace(),
+            batched_gemm_wgrad = functools.partial(
+                batchgemm,
+                dtype=ctx.activation_dtype,
+                workspaces=get_multi_stream_cublas_batchgemm_workspace(),
                layout="NT",
                grad=True,
                use_bias=ctx.use_bias,
                accumulate=accumulate_wgrad_into_param_main_grad,
            )

-            # Deallocate input tensor
-            clear_tensor_data(*inputmats)
-            clear_tensor_data(*inputmats_t)
-
-            if not ctx.use_bias:
-                grad_biases = [None] * ctx.num_gemms
+            # WGRAD
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
+            else:
+                _, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
+                for i in range(ctx.num_gemms):
+                    if grad_biases[i] is None:
+                        grad_biases[i] = grad_biases_[i]
+                del grad_biases_
+                # Deallocate input tensor
+                clear_tensor_data(*inputmats)
+                clear_tensor_data(*inputmats_t)

            def handle_custom_ddp_from_mcore(w, wgrad):
                if w.requires_grad:
                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                        w.grad_added_to_main_grad = True
                        if getattr(w, "zero_out_wgrad", False):
                            wgrad = torch.zeros(
                                w.main_grad.shape,
                                dtype=w.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False,
                            )
                        else:
                            wgrad = torch.empty(
                                w.main_grad.shape,
                                dtype=w.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False,
                            )
                    elif ctx.fuse_wgrad_accumulation:
                        wgrad = None
                else:
                    wgrad = None
                return wgrad

+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                wgrad_list = [None] * ctx.num_gemms
            wgrad_list = [
                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
            ]

+        if not ctx.use_bias or (
+            ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() and not ctx.fp8
+        ):
+            grad_biases = [None] * ctx.num_gemms
+
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
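The reworked backward above binds the batched wgrad GEMM's arguments with functools.partial and then either runs it immediately or hands it to the WeightGradStore imported from ._common. The store's real implementation is not part of this diff; the toy class below is only an illustration of the put/pop contract that backward and backward_dw rely on (queue a bound callable plus its tensors, execute it later):

    import functools
    from collections import deque

    class ToyWgradStore:
        """Illustrative stand-in for WeightGradStore, not the actual class."""

        def __init__(self, delay: bool):
            self._delay = delay
            self._queue = deque()

        def delay_wgrad_compute(self) -> bool:
            return self._delay

        def put(self, tensor_list, func):
            # backward() stores the bound GEMM instead of running it
            self._queue.append((tensor_list, func))

        def pop(self):
            # backward_dw() later retrieves and executes the stored GEMM
            tensor_list, func = self._queue.popleft()
            return func(*tensor_list), tensor_list

Under that contract, ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad) in backward pairs with (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() in backward_dw below.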
@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # wgrad_store
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
                          it controls the type used to allocate the initial parameters. Useful when
                          the model is trained with lower precision and the original FP32 parameters
                          would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether to delay weight gradient computation
    """

    def __init__(
@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
    ) -> None:
        super().__init__()
@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.wgrad_store = WeightGradStore(delay_wgrad_compute)

        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
        _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                CPUOffloadEnabled,
@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
        return out

    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
            wgrad_list = tensor_list[2]
            if not self.fuse_wgrad_accumulation:
                for i in range(self.num_gemms):
                    weight_param = getattr(self, f"weight{i}")
                    if weight_param.grad is None:
                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
            if self.use_bias:
                for i in range(self.num_gemms):
                    bias_param = getattr(self, f"bias{i}")
                    if bias_param.grad is None:
                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
            del grad_biases_
            del wgrad_list
            del tensor_list