OpenDAS / TransformerEngine / Commits

Commit 460b006c, authored May 20, 2025 by yuguo

    [DCU] support delay_wgrad_compute in batchgemm

Parent: 196a213f
Showing 2 changed files with 119 additions and 44 deletions:

    tests/pytorch/test_batched_linear.py                  (+29, -4)
    transformer_engine/pytorch/module/batched_linear.py   (+90, -40)
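What the commit does, in outline: BatchedLinear gains a delay_wgrad_compute flag. When it is set, the autograd backward computes only the data gradient and queues the operands of the weight-gradient batched GEMM in a WeightGradStore; a later call to backward_dw() replays the queued GEMM and populates the weight (and bias) gradients. A minimal usage sketch follows; the constructor arguments other than the new flag are illustrative assumptions, not the module's exact signature:

    import torch
    from transformer_engine.pytorch.module.batched_linear import BatchedLinear

    # NOTE: num_gemms/in_features/out_features are assumed names, for illustration only.
    layer = BatchedLinear(
        num_gemms=4,
        in_features=1024,
        out_features=1024,
        delay_wgrad_compute=True,  # new flag introduced by this commit
    ).cuda()

    inp = torch.randn(8, 1024, device="cuda", requires_grad=True)
    loss = layer(inp).sum()

    loss.backward()      # dgrad only; the wgrad batched GEMM is queued
    layer.backward_dw()  # replay the queued GEMM to materialize weight/bias grads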
tests/pytorch/test_batched_linear.py
@@ -171,7 +171,7 @@ def reset_global_fp8_state():
    FP8GlobalStateManager.reset()


def _test_batched_linear_accuracy(
    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
):
    reset_rng_states()
    if fp8:
@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
    )
    loss = out.sum()
    loss.backward()
    if delay_wgrad_compute:
        if isinstance(block, BatchedLinear):
            block.backward_dw()
        else:
            for i in range(num_gemms):
                block[i].backward_dw()
    torch.cuda.synchronize()

    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            if isinstance(block, BatchedLinear):
                if getattr(p, "main_grad", None) is not None:
                    for j in range(batch_num):
                        outputs.append(
                            p.main_grad[
                                p.main_grad.shape[0] // batch_num * j :
                                p.main_grad.shape[0] // batch_num * (j + 1)
                            ]
                        )
                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
                else:
                    for j in range(batch_num):
                        outputs.append(
                            p.grad[
                                p.grad.shape[0] // batch_num * j :
                                p.grad.shape[0] // batch_num * (j + 1)
                            ]
                        )
            else:
                if getattr(p, "main_grad", None) is not None:
                    outputs.append(p.main_grad)
                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
                else:
                    outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
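The slicing in the hunk above splits each gradient tensor, whose leading dimension stacks batch_num sub-batches, into equal chunks so they can be compared against the per-layer reference gradients one by one. A small self-contained check of that indexing (names are illustrative):

    import torch

    batch_num = 2
    main_grad = torch.arange(12.0).reshape(6, 2)  # leading dim stacks batch_num chunks

    rows = main_grad.shape[0] // batch_num        # rows per chunk
    chunks = [main_grad[rows * j : rows * (j + 1)] for j in range(batch_num)]

    assert torch.equal(torch.cat(chunks), main_grad)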
@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_batched_linear_accuracy(
    dtype,
    num_gemms,
@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
    delay_wgrad_compute,
    parallel_mode=None,
):
    batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
        parallel_mode=parallel_mode,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        delay_wgrad_compute=delay_wgrad_compute,
    ).eval()
    sequential_linear = torch.nn.ModuleList(
        [
@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
            sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                weight_i.main_grad.shape[0] // batch_num * j :
                weight_i.main_grad.shape[0] // batch_num * (j + 1)
            ].clone()

    outputs_ref = _test_batched_linear_accuracy(
        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
    )
    outputs = _test_batched_linear_accuracy(
        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
    )

    # Should be bit-wise match
@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
    torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


if __name__ == "__main__":
    test_batched_linear_accuracy(
        torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True
    )
transformer_engine/pytorch/module/batched_linear.py
@@ -6,7 +6,7 @@
import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
import functools

import torch
import transformer_engine_torch as tex
@@ -18,6 +18,7 @@ from .base import (
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
    divide,
@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        wgrad_store: WeightGradStore,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
        ctx.tp_size = tp_size
        ctx.requires_dgrad = inp.requires_grad
        ctx.reduce_and_update_bwd_fp8_tensors = False
        ctx.wgrad_store = wgrad_store

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])
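The final out.view(...) restores the caller's intermediate dimensions around the flattened GEMM output, as the comment notes. A quick shape check of that reshape, with illustrative sizes:

    import torch

    inp = torch.empty(4, 16, 128)   # e.g. [B, S, H_in], flattened to [B*S, H_in] for the GEMM
    out = torch.empty(4 * 16, 256)  # GEMM output [B*S, H_out]

    restored = out.view(-1, *inp.shape[1:-1], out.shape[-1])
    assert restored.shape == (4, 16, 256)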
@@ -246,53 +249,69 @@ class _BatchLinear(torch.autograd.Function):
                torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
                for w in weights
            ]
            # WGRAD
            batched_gemm_wgrad = functools.partial(
                batchgemm,
                dtype=ctx.activation_dtype,
                workspaces=get_multi_stream_cublas_batchgemm_workspace(),
                layout="NT",
                grad=True,
                use_bias=ctx.use_bias,
                accumulate=accumulate_wgrad_into_param_main_grad,
            )
            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
            else:
                _, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
                if not ctx.use_bias:
                    grad_biases = [None] * ctx.num_gemms
                for i in range(ctx.num_gemms):
                    if grad_biases[i] is None:
                        grad_biases[i] = grad_biases_[i]
                del grad_biases_

                # Deallocate input tensor
                clear_tensor_data(*inputmats)
                clear_tensor_data(*inputmats_t)

            def handle_custom_ddp_from_mcore(w, wgrad):
                if w.requires_grad:
                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                        w.grad_added_to_main_grad = True
                        if getattr(w, "zero_out_wgrad", False):
                            wgrad = torch.zeros(
                                w.main_grad.shape,
                                dtype=w.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False,
                            )
                        else:
                            wgrad = torch.empty(
                                w.main_grad.shape,
                                dtype=w.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False,
                            )
                    elif ctx.fuse_wgrad_accumulation:
                        wgrad = None
                else:
                    wgrad = None
                return wgrad

            wgrad_list = [
                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
            ]
        else:
            wgrad_list = [None] * ctx.num_gemms

        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
            wgrad_list = [None] * ctx.num_gemms
        if not ctx.use_bias or (
            ctx.wgrad_store is not None
            and ctx.wgrad_store.delay_wgrad_compute()
            and not ctx.fp8
        ):
            grad_biases = [None] * ctx.num_gemms

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
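The branch above relies on a small queue interface from WeightGradStore (imported from ._common): delay_wgrad_compute() reports whether deferral is on, put() stashes the GEMM operands together with the functools.partial built above, and pop() replays the oldest entry. A minimal sketch of such a store, offered as an assumption about what ._common provides rather than its actual implementation:

    from collections import deque

    class WeightGradStore:
        """Minimal sketch of the queue interface the backward pass relies on."""

        def __init__(self, delay_wgrad_compute: bool = False):
            self._delay = delay_wgrad_compute
            self._queue = deque()

        def delay_wgrad_compute(self) -> bool:
            # True when wgrad GEMMs should be deferred to backward_dw().
            return self._delay

        def put(self, tensor_list, wgrad_func):
            # Stash the operands together with the prepared wgrad GEMM callable.
            self._queue.append((tensor_list, wgrad_func))

        def pop(self):
            # Replay the oldest deferred wgrad GEMM; return its outputs along
            # with the stored tensors, matching backward_dw() below.
            tensor_list, wgrad_func = self._queue.popleft()
            return wgrad_func(*tensor_list), tensor_list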
@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # wgrad_store
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
                 it controls the type used to allocate the initial parameters. Useful when
                 the model is trained with lower precision and the original FP32 parameters
                 would not fit in GPU memory.
    delay_wgrad_compute : bool, default = `False`
                         Whether to delay weight gradient computation.
    """

    def __init__(
@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
    ) -> None:
        super().__init__()
@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.wgrad_store = WeightGradStore(delay_wgrad_compute)

        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
        _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.wgrad_store,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                CPUOffloadEnabled,
@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
        return out

    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
            wgrad_list = tensor_list[2]
            if not self.fuse_wgrad_accumulation:
                for i in range(self.num_gemms):
                    weight_param = getattr(self, f"weight{i}")
                    if weight_param.grad is None:
                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
            if self.use_bias:
                for i in range(self.num_gemms):
                    bias_param = getattr(self, f"bias{i}")
                    if bias_param.grad is None:
                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
            del grad_biases_
            del wgrad_list
            del tensor_list
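Because backward_dw() returns immediately when deferral is off, callers can invoke it unconditionally. A hedged sketch of where it fits in a training step (model and optimizer are illustrative; BatchedLinear modules are assumed to be constructed with delay_wgrad_compute=True):

    def train_step(model, optimizer, batch):
        out = model(batch)
        loss = out.float().sum()

        loss.backward()          # dgrad now; wgrad GEMMs queued per BatchedLinear

        # Overlappable window: e.g. launch gradient communication here before
        # spending FLOPs on the deferred weight-gradient GEMMs.
        for m in model.modules():
            if hasattr(m, "backward_dw"):
                m.backward_dw()  # safe no-op when delay_wgrad_compute=False

        optimizer.step()
        optimizer.zero_grad()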