OpenDAS / TransformerEngine / Commits

Commit 27ddce40, authored Oct 11, 2025 by wenjh
Merge branch 'nv_main'
Parents: d262ef4c, 5b3092a0
Changes: 208

Showing 20 changed files with 619 additions and 169 deletions (+619 −169)
transformer_engine/pytorch/module/layernorm_mlp.py                      +54  −37
transformer_engine/pytorch/module/linear.py                             +29  −11
transformer_engine/pytorch/onnx_extensions.py                           +51  −11
transformer_engine/pytorch/ops/_common.py                               +3   −1
transformer_engine/pytorch/ops/basic/__init__.py                        +1   −1
transformer_engine/pytorch/ops/basic/activation.py                      +142 −13
transformer_engine/pytorch/ops/basic/basic_linear.py                    +33  −11
transformer_engine/pytorch/ops/basic/dropout.py                         +55  −17
transformer_engine/pytorch/ops/basic/l2normalization.py                 +6   −3
transformer_engine/pytorch/ops/basic/layer_norm.py                      +5   −2
transformer_engine/pytorch/ops/basic/rmsnorm.py                         +5   −2
transformer_engine/pytorch/ops/fused/__init__.py                        +4   −0
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py            +133 −0
transformer_engine/pytorch/ops/fused/backward_linear_add.py             +24  −14
transformer_engine/pytorch/ops/fused/backward_linear_scale.py           +24  −14
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py  +6   −7
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py         +7   −8
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py        +4   −1
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py     +29  −15
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py      +4   −1
transformer_engine/pytorch/module/layernorm_mlp.py
...
...
@@ -94,39 +94,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         # bf16 (recipe is None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
-            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
+            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, None),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     if recipe.delayed() or recipe.mxfp8():
         # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
         # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
         return {
             "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
-            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
+            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     # no activation fusion written yet
     # Per-tensor current scaling or fp8 blockwise scaling: []
     if recipe.float8_current_scaling() or recipe.float8_block_scaling():
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
-            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
+            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, None),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     raise NotImplementedError(f"Unhandled recipe type {recipe}")
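For orientation (my sketch, not part of the diff): each mapping entry is a (forward kernel, backward kernel, fused dbias+dact backward) triple, and a None third element means the recipe has no fused bias-gradient kernel for that activation, so callers fall back to separate dbias and dact steps. The helper below is hypothetical and only illustrates the lookup.

    # Hypothetical helper (illustration only) showing how the mapping is consumed.
    def pick_activation_funcs(activation: str, recipe=None):
        supported = _get_act_func_supported_list(recipe)  # function shown above
        if activation not in supported:
            raise ValueError(f"Activation {activation!r} not supported for this recipe")
        act_fn, dact_fn, dbias_dact_fn = supported[activation]
        # Fuse the bias gradient with the activation gradient only when a kernel exists
        return act_fn, dact_fn, dbias_dact_fn is not None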
...
...
@@ -308,7 +314,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
         if ub_overlap_ag:
             # Copy into Userbuffers buffer
-            ub_obj_lnout = get_ub("fc1_fprop")
+            ub_obj_lnout = get_ub("fc1_fprop", fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_lnout,
                 ln_out,
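The recurring change throughout this commit is that get_ub now takes a second argument saying whether the buffer holds FP8 data. A minimal sketch of the keying convention this implies (my assumption; names illustrative, not the actual registry code):

    # Minimal sketch, assuming communicators are stored in a dict keyed by
    # (buffer_name, fp8) so FP8 and non-FP8 paths get distinct buffers.
    _ub_registry = {}

    def get_ub_sketch(name: str, fp8: bool):
        key = (name, fp8)
        if key not in _ub_registry:
            raise KeyError(f"No Userbuffers communicator initialized for {key}")
        return _ub_registry[key]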
...
...
@@ -446,20 +452,25 @@ class _LayerNormMLP(torch.autograd.Function):
                     act_out = activation_func(fc1_out, None)
                     act_out = tex.quantize(act_out, fc2_input_quantizer)
                 else:
-                    act_out = activation_func(fc1_out, fc2_input_quantizer)
+                    if fp8_calibration:
+                        act_out = activation_func(fc1_out, None)
+                    else:
+                        act_out = activation_func(fc1_out, fc2_input_quantizer)

         if not is_grad_enabled:
             clear_tensor_data(fc1_out)

-        if fp8_calibration:
-            fc2_input_quantizer.calibrate(act_out)
-            fc2_weight_quantizer.calibrate(fc2_weight)
+        if not fp8 and fp8_calibration:
+            if fc2_input_quantizer is not None:
+                fc2_input_quantizer.calibrate(act_out)
+            if fc2_weight_quantizer is not None:
+                fc2_weight_quantizer.calibrate(fc2_weight)

         # Configure Userbuffers reduce-scatter if needed
         ub_obj_fc2out = None
         reduce_scatter_out = None
         if ub_overlap_rs:
-            ub_obj_fc2out = get_ub("fc2_fprop")
+            ub_obj_fc2out = get_ub("fc2_fprop", fp8)
             dim_size = list(act_out.size())
             dim_size[0] //= tp_world_size
             dim_size[-1] = fc2_weight.size(0)
...
...
@@ -741,7 +752,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
-            ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
+            ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
             ctx.ub_obj_gradout = ub_obj_fc2_dgrad
         (
             grad_output,
...
...
@@ -765,7 +776,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 # wgrad GEMM requires input with column-wise usage
                 quantizer.set_usage(rowwise=False, columnwise=True)
         if ctx.ub_bulk_dgrad:
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_fc1_dgrad,
                 ln_out,
...
...
@@ -870,7 +881,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     ub_obj_fc2_dgrad.get_communication_stream()
                 )
-                ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
+                ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)
                 ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
...
@@ -1045,16 +1056,16 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
         if ctx.ub_overlap_rs_dgrad:
             # Overlap DGRAD+RS
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ub_type_fc1_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap ln_out all-gather with DGRAD compute
-                ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
                 ub_type_fc1_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
-                ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
+                ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
                 ub_type_fc1_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
...
...
@@ -1402,7 +1413,7 @@ class _LayerNormMLP(torch.autograd.Function):
 class LayerNormMLP(TransformerEngineBaseModule):
     r"""
     Applies layer normalization on the input followed by the MLP module, consisting of
-    2 successive linear transformations, separated by the GeLU activation.
+    2 successive linear transformations, separated by the activation function.

     Parameters
     ----------
...
...
@@ -1418,7 +1429,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    type of normalization applied.
     activation : str, default = 'gelu'
                  activation function used.
-                 Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'.
+                 Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
+                 'silu', and 'swiglu'.
     init_method : Callable, default = `None`
                   used for initializing FC1 weights in the following way: `init_method(weight)`.
                   When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
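Illustrative usage of the expanded option list (my example, not from the diff; assumes a CUDA build of Transformer Engine):

    import torch
    import transformer_engine.pytorch as te

    # Any activation from the list above is accepted, including the new ones.
    mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="sreglu").cuda()
    out = mlp(torch.randn(8, 1024, device="cuda"))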
...
...
@@ -1559,7 +1571,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
+            and all(
+                ("fc1_fprop", use_fp8) not in _ub_communicators
+                or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
+                for use_fp8 in [False, True]
+            )
         )
         self.name = name
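For context (my note, not from the diff): the fusion stays opt-in via an environment variable and now checks both the FP8 and non-FP8 "fc1_fprop" communicators for atomic GEMM. Enabling it looks like:

    import os

    # Must be set before the module is constructed; "0" (disabled) is the default.
    os.environ["NVTE_GEMM_GELU_FUSION"] = "1"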
...
...
@@ -1619,7 +1635,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias = None

         # FC1 init
-        if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]:
+        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
...
...
@@ -1777,7 +1793,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fp8_output = False
         if self.ub_overlap_rs:
-            if get_ub("fc2_fprop").is_fp8_ubuf():
+            if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True

         with torch.cuda.device(
...
...
@@ -1915,7 +1931,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc2_grad_output_quantizer,
         ) = [None] * 10
         fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
-        if self.fp8:
+        if self.fp8 or self.fp8_calibration:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             fc1_input_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
...
...
@@ -2001,14 +2017,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
         activation_map = {
             "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
-            "relu": torch.nn.functional.relu,
             "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
-            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
-            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
             "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
-            "srelu": torch.nn.functional.softplus,
+            "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
+            * x.chunk(2, -1)[1],
+            "relu": torch.nn.functional.relu,
+            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "srelu": lambda x: torch.nn.functional.relu(x) ** 2,
+            "sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 * x.chunk(2, -1)[1],
+            "silu": torch.nn.functional.silu,
+            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
         }
         if self.activation not in activation_map:
             raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
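A standalone sanity check (my addition, not from the diff) that the gated lambdas above follow the "gate the first chunk" convention, using sreglu as the example:

    import torch

    def sreglu_ref(x: torch.Tensor) -> torch.Tensor:
        a, b = x.chunk(2, dim=-1)          # gating function applies to the first half
        return torch.nn.functional.relu(a) ** 2 * b

    x = torch.randn(4, 8)
    onnx_lambda = lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 * x.chunk(2, -1)[1]
    assert torch.allclose(sreglu_ref(x), onnx_lambda(x))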
...
...
@@ -2129,7 +2148,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None, None]
         fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         fc1_weight_quantizer.internal = True
...
...
@@ -2182,10 +2201,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             if self.fc1_bias.grad is None:
                 self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
         if not self.fuse_wgrad_accumulation:
-            if self.fc2_weight.grad is None:
-                self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
-            if self.fc1_weight.grad is None:
-                self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
+            self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
+            self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)

         del fc2_bias_grad_
         del fc2_wgrad
         del fc1_wgrad
...
...
transformer_engine/pytorch/module/linear.py
...
...
@@ -147,10 +147,10 @@ class _Linear(torch.autograd.Function):
         ub_obj = None
         ub_type = None
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
         # ------------------------------------------------------
...
...
@@ -319,6 +319,13 @@ class _Linear(torch.autograd.Function):
         # Finished forward GEMM...
         # ------------------------------------------------------

+        # Deallocate GEMM input tensor if no longer needed
+        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
+        # deallocated by GC. Manually deallocating is a temporary hack.
+        if with_input_all_gather_nccl:
+            clear_tensor_data(inputmat_total)
+            inputmat_total = None
+
         # ------------------------------------------------------
         # Prepare output tensor
         # Note: Perform tensor-parallel communication
...
...
@@ -544,23 +551,23 @@ class _Linear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
...
...
@@ -793,7 +800,7 @@ class _Linear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
...
@@ -905,9 +912,16 @@ class _Linear(torch.autograd.Function):
                 grad_bias = grad_bias_
             del grad_bias_

-        # Deallocate input tensor if permitted
+        # Deallocate tensors if permitted
         if ctx.owns_input:
             # Input tensor is internal
             clear_tensor_data(inputmat_total)
+        elif ctx.backward_input_needs_gather:
+            # Gathered input tensor is internal
+            clear_tensor_data(inputmat_total)
+        if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+            # Gathered grad output tensor is internal
+            clear_tensor_data(grad_output)

         # Update grad input if overlapping reduce-scatter with wgrad GEMM
         if ctx.ub_bulk_wgrad:
...
...
@@ -1404,10 +1418,14 @@ class Linear(TransformerEngineBaseModule):
             is_first_microbatch = False
         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True

         with torch.cuda.device(
...
...
@@ -1666,7 +1684,7 @@ class Linear(TransformerEngineBaseModule):
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None]
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
...
...
transformer_engine/pytorch/onnx_extensions.py
...
...
@@ -112,7 +112,9 @@ schema = defs.OpSchema(
     doc="TRT FP8 Quantize Linear used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for quantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
 )
...
...
@@ -126,11 +128,10 @@ TRT_FP8QuantizeLinear = onnxscript.values.Op(
 @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
-def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
+def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
     """Dequantize from Float8Tensor used for inference."""
-    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
     quantizer = Float8Quantizer(
-        scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
+        1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
     )
     quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
     return quantizer_tensor.dequantize()
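Numerically, the new calling convention amounts to the following (my sketch, not from the diff; the real op decodes FP8 payloads rather than taking plain floats):

    import torch

    scale = torch.tensor(64.0)              # quantization scale used on the way in
    scale_inv = 1.0 / scale                 # what the op now receives, as a tensor
    decoded = torch.tensor([32.0, -16.0])   # stand-in for decoded FP8 values
    dequantized = decoded * scale_inv       # elementwise rescale back to the original range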
...
...
@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor:
 def onnx_dequantize_fp8_symbolic(
-    tensor: onnxscript.onnx_types.TensorType, scale: float
+    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
 ) -> onnxscript.onnx_types.TensorType:
     """Symbolic dequantize from Float8Tensor used for inference."""
-    scale_inv = op.Constant(value_float=1 / scale)
     return TRT_FP8DequantizeLinear(tensor, scale_inv)
...
...
@@ -157,7 +157,9 @@ schema = defs.OpSchema(
     doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
 )
...
...
@@ -166,6 +168,43 @@ TRT_FP8DequantizeLinear = onnxscript.values.Op(
     opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
 )

+
+# ONNX FP8 Current Scaling Quantization
+@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
+def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
+    if tensor.dtype != torch.float32:
+        tensor = tensor.to(torch.float32)
+    amax = tensor.abs().max()
+    eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
+    amax = torch.maximum(amax, eps)
+    fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
+    scale = fp8_max / amax
+    q = torch.ops.tex.fp8_quantize(tensor, scale)
+    scale_inv = 1 / scale
+    return q, scale_inv
+
+
+@onnx_cs_quantize_fp8_op.register_fake
+def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
+        1, dtype=torch.float32, device=tensor.device
+    )
+
+
+def onnx_quantize_fp8_cs_symbolic(
+    tensor: onnxscript.onnx_types.TensorType,
+):
+    """Symbolic quantize with current scaling; computes scale_inv from tensor."""
+    # scale_inv = max(abs(tensor)) / 448 (the FP8 E4M3 max-normal value)
+    amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
+    eps = op.Constant(value_float=1.0e-12)
+    amax = op.Max(amax, eps)
+    scale_inv = op.Div(amax, op.Constant(value_float=448.0))
+    q = TRT_FP8QuantizeLinear(tensor, scale_inv)
+    return q, scale_inv

 # ONNX MXFP8 Quantization
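The arithmetic behind current scaling, extracted as a plain-PyTorch illustration (mine, not part of the diff): the scale maps the tensor's amax onto 448, the E4M3 max-normal value, and scale_inv is exported alongside the payload.

    import torch

    t = torch.randn(1024) * 3.0
    amax = torch.maximum(t.abs().max(), torch.tensor(1e-12))
    scale = 448.0 / amax
    scale_inv = 1.0 / scale
    # Quantize-dequantize round trip (clamping stands in for the FP8 cast)
    q = torch.clamp(t * scale, min=-448.0, max=448.0)
    roundtrip = q * scale_inv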
...
...
@@ -194,12 +233,12 @@ def onnx_quantize_mxfp8_symbolic(
     tensor: onnxscript.onnx_types.TensorType,
 ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
     """Symbolic quantize to MXFP8Tensor used for inference."""
-    tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
+    tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
     return tensor_out, scale_inv_out


 schema = defs.OpSchema(
-    name="TRT_MXFP8QuantizeLinear",
+    name="TRT_MXFP8DynamicQuantize",
     domain="trt",
     since_version=1,
     doc="TRT MXFP8 Quantize Linear used for inference.",
...
...
@@ -214,8 +253,8 @@ schema = defs.OpSchema(
     ],
 )

-TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
-    opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
+TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
+    opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
 )
...
...
@@ -356,6 +395,7 @@ te_translation_table = {
     torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
     torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
     torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
+    torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
     torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
     torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
     torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
...
...
transformer_engine/pytorch/ops/_common.py
...
...
@@ -29,7 +29,9 @@ def maybe_dequantize(
     if is_quantized_tensor(tensor):
         return tensor.dequantize(dtype=dtype)
     if dtype is not None and tensor.dtype != dtype:
-        return tensor.to(dtype)
+        tensor = tensor.to(dtype)
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
     return tensor
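The practical effect of the change (my example, not from the diff): callers can now rely on a contiguous result even when the input was a strided view.

    import torch

    x = torch.randn(4, 8).t()      # transposed, non-contiguous view
    assert not x.is_contiguous()
    y = x.contiguous()             # the step the helper now performs before returning
    assert y.is_contiguous() and torch.equal(x, y)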
...
...
transformer_engine/pytorch/ops/basic/__init__.py
...
...
@@ -4,7 +4,7 @@
 """Single tensor operations supported by the operation fuser."""

-from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
+from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
 from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
...
...
transformer_engine/pytorch/ops/basic/activation.py
...
...
@@ -11,11 +11,25 @@ from typing import Optional
 import torch

 import transformer_engine_torch as tex
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
 from ...utils import clear_tensor_data
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_dequantize

+__all__ = [
+    "GELU",
+    "GEGLU",
+    "QGELU",
+    "QGEGLU",
+    "ReLU",
+    "ReGLU",
+    "SReLU",
+    "SReGLU",
+    "SiLU",
+    "SwiGLU",
+]
+

 class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
     r"""Apply activation function
...
...
@@ -97,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x)
             ctx.save_for_backward(x)
             ctx.dtype = dtype
             ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
...
...
@@ -147,37 +163,75 @@ class GELU(_ActivationOperation):
         return tex.dgelu(*args, **kwargs)


-class ReLU(_ActivationOperation):
-    r"""Rectified linear unit
+class GEGLU(_ActivationOperation):
+    r"""Gaussian Error Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:

     .. math::

-       \text{ReLU}(x) = \max(x,0)
+       \text{GEGLU}(a,b) = \text{GELU}(a) * b
+
+    where
+
+    .. math::
+
+       \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+
+    See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__.

     """

     def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.relu(*args, **kwargs)
+        return tex.geglu(*args, **kwargs)

     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.drelu(*args, **kwargs)
+        return tex.dgeglu(*args, **kwargs)


-class GEGLU(_ActivationOperation):
-    r"""Gaussian error gated linear unit
+class QGELU(_ActivationOperation):
+    r"""Quick Gaussian Error Linear Unit
+
+    Quick GELU from `HuggingFace<https://github.com/huggingface/transformers/blob/3e93dd295b5343557a83bc07b0b2ea64c926f9b4/src/transformers/activations.py#L90>`__
+    and `paper<https://github.com/hendrycks/GELUs>`__.
+
+    .. math::
+
+       \text{QGELU}(x) \approx x * \sigma(1.702 * x)
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.qgelu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dqgelu(*args, **kwargs)
+
+
+class QGEGLU(_ActivationOperation):
+    r"""Quick Gaussian Error Gated Linear Unit

     The input tensor is split into chunks :math:`a` and :math:`b`
     along the last dimension and the following is computed:

     .. math::

-       \text{GEGLU}(a,b) = \text{GELU}(a) * b
+       \text{QGEGLU}(a,b) = \text{QGELU}(a) * b

     where

     .. math::

-       \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+       \text{QGELU}(x) \approx x * \sigma(1.702 * x)

     .. warning::
...
...
@@ -187,19 +241,33 @@ class GEGLU(_ActivationOperation):
        the first half of the input tensor, while PyTorch applies it to
        the second half.

     See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__.

     """

     def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.geglu(*args, **kwargs)
+        return tex.qgeglu(*args, **kwargs)

     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.dgeglu(*args, **kwargs)
+        return tex.dqgeglu(*args, **kwargs)
+
+
+class ReLU(_ActivationOperation):
+    r"""Rectified Linear Unit
+
+    .. math::
+
+       \text{ReLU}(x) = \max(x,0)
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.relu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.drelu(*args, **kwargs)


 class ReGLU(_ActivationOperation):
-    r"""Rectified gated linear unit
+    r"""Rectified Gated Linear Unit

     The input tensor is split into chunks :math:`a` and :math:`b`
     along the last dimension and the following is computed:
...
...
@@ -227,6 +295,67 @@ class ReGLU(_ActivationOperation):
         return tex.dreglu(*args, **kwargs)


+class SReLU(_ActivationOperation):
+    r"""Squared Rectified Linear Unit
+
+    .. math::
+
+       \text{SReLU}(x) = \max(x^2,0)
+
+    See `Primer: Searching for Efficient Transformers for Language Modeling<https://arxiv.org/abs/2109.08668v2>`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.srelu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsrelu(*args, **kwargs)
+
+
+class SReGLU(_ActivationOperation):
+    r"""Squared Rectified Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{SReGLU}(a,b) = \max(a^2,0) * b
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.sreglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsreglu(*args, **kwargs)
+
+
+class SiLU(_ActivationOperation):
+    r"""Sigmoid Linear Unit
+
+    .. math::
+
+       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.silu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsilu(*args, **kwargs)
+
+
 class SwiGLU(_ActivationOperation):
     r"""Swish gated linear unit
...
...
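Plain-PyTorch references (my sketch, independent of the tex kernels) for the ops introduced here, matching the formulas in the docstrings above:

    import torch

    def qgelu_ref(x):     # quick GELU: x * sigmoid(1.702 x)
        return x * torch.sigmoid(1.702 * x)

    def silu_ref(x):      # x * sigmoid(x)
        return x * torch.sigmoid(x)

    def srelu_ref(x):     # squared ReLU
        return torch.nn.functional.relu(x) ** 2

    def sreglu_ref(x):    # squared ReLU applied to the first chunk, gated by the second
        a, b = x.chunk(2, dim=-1)
        return srelu_ref(a) * b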
transformer_engine/pytorch/ops/basic/basic_linear.py
...
...
@@ -12,26 +12,32 @@ from typing import Any, Optional
 import torch

-from transformer_engine.pytorch.module.base import get_workspace
 from ...cpp_extensions import general_gemm
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import (
     CudaRNGStatesTracker,
     gather_along_first_dim,
     reduce_scatter_along_first_dim,
 )
 from ...fp8 import FP8GlobalStateManager, Recipe
-from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
+from ...module.base import (
+    _2X_ACC_FPROP,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+    get_dummy_wgrad,
+    get_workspace,
+)
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ..op import BasicOperation, OperationContext
-from .._common import maybe_dequantize, is_quantized_tensor
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
     clear_tensor_data,
     devices_match,
 )
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize, is_quantized_tensor


 def _wait_async(handle: Optional[Any]) -> None:
...
...
@@ -73,7 +79,8 @@ class BasicLinear(BasicOperation):
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
-        meaningful.
+        meaningful. This is primarily intended to integrate with
+        Megatron-LM.
     userbuffers_options, dict, optional
         Options for overlapping tensor-parallel communication with
         compute using Userbuffers. This feature is highly
...
...
@@ -958,6 +965,8 @@ class BasicLinear(BasicOperation):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             ctx.save_for_backward(x_local, w)
             ctx.with_quantized_compute = with_quantized_compute
             ctx.input_quantizer = input_quantizer
...
...
@@ -979,20 +988,22 @@ class BasicLinear(BasicOperation):
         # Saved tensors from forward pass
         (x_local, w) = ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
         # Note: Get grad tensor from param so we can accumulate
         # directly into it.
         accumulate_into_main_grad = self._accumulate_into_main_grad
         grad_weight = None
         if ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(self.weight, "__fsdp_param__"):
-                self.weight.main_grad = self.weight.get_main_grad()
-            if not hasattr(self.weight, "main_grad"):
+            weight_param = self.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = self.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
...
...
@@ -1019,6 +1030,17 @@ class BasicLinear(BasicOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = self.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
         return grad_input, [grad_weight]
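The Megatron-LM contract assumed by this path, in miniature (my sketch, not from the diff): the framework attaches a persistent fp32 main_grad buffer to the parameter, gradients are accumulated into it directly, and autograd only ever sees a dummy wgrad.

    import torch

    w = torch.nn.Parameter(torch.randn(16, 16))
    w.main_grad = torch.zeros(16, 16, dtype=torch.float32)  # set externally (e.g. by Megatron-LM)
    w.grad_added_to_main_grad = False

    wgrad = torch.randn(16, 16)        # stand-in for the computed weight gradient
    w.main_grad += wgrad               # accumulate directly, bypassing .grad
    w.grad_added_to_main_grad = True   # tells downstream code that .grad is a dummy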
transformer_engine/pytorch/ops/basic/dropout.py
...
...
@@ -8,12 +8,12 @@ from __future__ import annotations
 from typing import Optional

 import torch

-from transformer_engine.pytorch.ops.op import (
-    BasicOperation,
-    OperationContext,
-)
+import transformer_engine_torch as tex
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor import Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
+from .._common import maybe_autocast_dtype, maybe_dequantize
+from ..op import BasicOperation, OperationContext


 class Dropout(BasicOperation):
...
...
@@ -27,7 +27,7 @@ class Dropout(BasicOperation):
     def __init__(self, p: float) -> None:
         super().__init__()
-        self.dropout_probability = p
+        self.dropout_probability: float = p

     def op_forward(
         self,
...
...
@@ -37,21 +37,46 @@ class Dropout(BasicOperation):
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

-        # Compute dropout if training
-        out = input_
-        is_training = self.training
-        mask = None
-        if is_training:
+        # Output dtype
+        dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
+
+        # Choose implementation
+        impl = None
+        if not self.training:
+            impl = "evaluation"
+        elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16):
+            impl = "fused"
+        else:
+            impl = "unfused"
+
+        # Perform dropout
+        out: torch.Tensor
+        mask: Optional[torch.Tensor] = None
+        if impl == "evaluation":
+            out = input_
+        elif impl == "fused":
+            x = input_
+            if not isinstance(x, Float8TensorBase):
+                x = maybe_dequantize(x, dtype=dtype)
+            out, mask = tex.dropout_fwd(x, self.dropout_probability)
+        elif impl == "unfused":
+            x = maybe_dequantize(input_, dtype=dtype)
             keep_prob = 1 - self.dropout_probability
-            mask = torch.empty_like(input_)
+            mask = torch.empty_like(x)
             mask.bernoulli_(keep_prob)
             mask *= 1 / keep_prob
-            out = out * mask
+            out = x * mask
+        else:
+            raise ValueError(f"Unsupported forward implementation {impl}")

         # Save context for backward
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(mask)
             ctx.save_for_backward(mask)
-            ctx.is_training = is_training
+            ctx.impl = impl
+            ctx.dropout_probability = self.dropout_probability
+            ctx.dtype = dtype

         return out
...
...
@@ -60,8 +85,21 @@ class Dropout(BasicOperation):
         ctx: OperationContext,
         grad_output: torch.Tensor,
     ) -> tuple[torch.Tensor, tuple[()]]:

         # Saved tensors from forward pass
         (mask,) = ctx.saved_tensors

-        grad_input = grad_output
-        if ctx.is_training:
-            grad_input = grad_input * mask
+        # Perform dropout backward pass
+        grad_input: torch.Tensor
+        if ctx.impl == "evaluation":
+            grad_input = grad_output
+        elif ctx.impl == "fused":
+            dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
+            grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability)
+        elif ctx.impl == "unfused":
+            dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
+            grad_input = dy * mask
+        else:
+            raise ValueError(f"Unsupported backward implementation {ctx.impl}")

         return grad_input, ()
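The selection rule encoded in the forward pass, restated as a small standalone function (my illustration, not from the diff):

    import torch

    def choose_dropout_impl(x: torch.Tensor, training: bool) -> str:
        if not training:
            return "evaluation"                    # identity; nothing saved
        if x.numel() % 16 == 0 and x.dtype in (torch.float16, torch.bfloat16):
            return "fused"                         # tex.dropout_fwd / tex.dropout_bwd
        return "unfused"                           # Bernoulli mask in plain PyTorch

    assert choose_dropout_impl(torch.empty(32, dtype=torch.bfloat16), True) == "fused"
    assert choose_dropout_impl(torch.empty(33), True) == "unfused"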
transformer_engine/pytorch/ops/basic/l2normalization.py
...
...
@@ -10,10 +10,8 @@ import os

 import torch

-from ...utils import clear_tensor_data
 from ... import torch_version
-from .._common import maybe_dequantize
-from ..op import BasicOperation, OperationContext
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...jit import (
     l2normalization_fused,
     l2normalization_fwd_fused,
...
...
@@ -22,6 +20,9 @@ from ...jit import (
     warmup_jit_l2normalization_all_dtypes,
 )
 from ...tensor import Quantizer
+from ...utils import clear_tensor_data
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize


 class L2Normalization(BasicOperation):
...
...
@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation):
         # Save state for backward pass
         if requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x, rsqrt_norm)
             ctx.save_for_backward(x, rsqrt_norm)

         return y
...
...
transformer_engine/pytorch/ops/basic/layer_norm.py
...
...
@@ -14,6 +14,9 @@ import torch
 from transformer_engine_torch import layernorm_bwd, layernorm_fwd

 from ...constants import TE_DType
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...export import is_in_onnx_export_mode
+from ...tensor import Quantizer
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
...
...
@@ -22,8 +25,6 @@ from ...utils import (
 )
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, maybe_dequantize
-from ...export import is_in_onnx_export_mode
-from ...tensor import Quantizer


 class LayerNorm(BasicOperation):
...
...
@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x, means, rstdevs)
             ctx.save_for_backward(x, means, rstdevs)
             ctx.dtype = dtype
...
...
transformer_engine/pytorch/ops/basic/rmsnorm.py
...
...
@@ -14,6 +14,9 @@ import torch
 from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd

 from ...constants import TE_DType
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...export import is_in_onnx_export_mode
+from ...tensor import Quantizer
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
...
...
@@ -22,8 +25,6 @@ from ...utils import (
 )
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, maybe_dequantize
-from ...export import is_in_onnx_export_mode
-from ...tensor import Quantizer


 class RMSNorm(BasicOperation):
...
...
@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x, rstdevs)
             ctx.save_for_backward(x, rstdevs)
             ctx.dtype = dtype
...
...
transformer_engine/pytorch/ops/fused/__init__.py
...
...
@@ -8,6 +8,10 @@ from .backward_activation_bias import (
     BackwardActivationBias,
     fuse_backward_activation_bias,
 )
+from .backward_add_rmsnorm import (
+    BackwardAddRMSNorm,
+    fuse_backward_add_rmsnorm,
+)
 from .backward_linear_add import (
     BackwardLinearAdd,
     fuse_backward_linear_add,
...
...
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py (new file, 0 → 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused backward RMSNorm + add."""

from __future__ import annotations
from typing import Optional
import math

import torch
import transformer_engine_torch as tex

from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize


class BackwardAddRMSNorm(FusedOperation):
    """Fused backward RMSNorm + add"""

    def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm):
        super().__init__((add, rmsnorm))

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:

        # Get basic operations
        rmsnorm_op = self.basic_ops[1]
        rmsnorm_op_ctx = basic_op_ctxs[0]

        # Saved tensors from forward pass
        x, rstdevs = rmsnorm_op_ctx.saved_tensors

        # Tensor dims
        weight_dims = rmsnorm_op.weight.size()
        inner_dim = math.prod(weight_dims)

        # Check input tensors
        dtype = rmsnorm_op_ctx.dtype
        extra_grad = basic_op_grad_extra_outputs[1][0]
        dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
        w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
        add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())

        # Compute RMSNorm backward pass
        dx, dw = tex.rmsnorm_bwd_add(
            dy,
            x,
            add,
            rstdevs,
            w,
            rmsnorm_op._sm_margins["backward"],
            rmsnorm_op.zero_centered_gamma,
        )

        # Clear saved tensors if possible
        clear_tensor_data(x)
        clear_tensor_data(rstdevs)

        # Reshape results
        grad_input = dx.view(grad_output.size())
        grad_weight = dw.view(weight_dims)

        return grad_input, [(grad_weight,), ()], [(), ()]


def fuse_backward_add_rmsnorm(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward RMSNorm + add

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Scan through ops, fusing if possible
    out = []
    window = []
    while len(ops) >= 2:
        out.extend(window)

        # Check if first op is RMSNorm
        window, ops = ops[:1], ops[1:]
        op, _ = window[0]
        if not isinstance(op, RMSNorm):
            continue

        # Check if second op is "make extra output"
        op, _ = ops[0]
        if not isinstance(op, MakeExtraOutput):
            continue
        if op._in_place:
            continue
        window.extend(ops[:1])
        ops = ops[1:]

        # Replace window with fused op
        op = BackwardAddRMSNorm(
            rmsnorm=window[0][0],
            add=window[1][0],
        )
        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
        window = [(op, basic_op_idxs)]

    # Return list of ops
    out.extend(window)
    out.extend(ops)
    return out
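Hedged usage sketch (mine, not from the diff): the pass scans the backward op list pairwise and, when an RMSNorm op is followed by a non-in-place MakeExtraOutput op, replaces the pair with a single BackwardAddRMSNorm. The op objects below are placeholders for properly constructed instances inside the fuser.

    # Illustration only: rmsnorm_op and add_op stand for real RMSNorm /
    # MakeExtraOutput instances built elsewhere by the operation fuser.
    ops = [(rmsnorm_op, [0]), (add_op, [1])]
    fused = fuse_backward_add_rmsnorm(ops)
    # -> [(BackwardAddRMSNorm(rmsnorm=rmsnorm_op, add=add_op), [0, 1])]
    # when add_op is not in-place; otherwise the list is returned unchanged.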
transformer_engine/pytorch/ops/fused/backward_linear_add.py
...
...
@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
 from ...utils import clear_tensor_data
+from ..basic import BasicLinear, MakeExtraOutput
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class BackwardLinearAdd(FusedOperation):
...
...
@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
         # Note: Get grad tensor from param so we can accumulate
         # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
...
...
@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
         return grad_input, [(grad_weight,), ()], [(), ()]
...
...
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
...
...
@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from ..basic import BasicLinear, ConstantScale
-from ..op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
 from ...utils import clear_tensor_data
+from ..basic import BasicLinear, ConstantScale
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class BackwardLinearScale(FusedOperation):
...
...
@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
         # Note: Get grad tensor from param so we can accumulate
         # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
...
...
@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
         return grad_input, [(), (grad_weight,)], [(), ()]
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
...
...
@@ -10,14 +10,11 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...fp8 import FP8GlobalStateManager
 from ...tensor import Quantizer
+from ..basic import BasicLinear, Bias
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class ForwardLinearBiasActivation(FusedOperation):
...
...
@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation):
         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
...
...
@@ -10,14 +10,11 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
-from transformer_engine.pytorch.tensor import Quantizer
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...fp8 import FP8GlobalStateManager
+from ...tensor import Quantizer
+from ..basic import AddExtraInput, BasicLinear, Bias
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class ForwardLinearBiasAdd(FusedOperation):
...
...
@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation):
         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
...
...
@@ -10,14 +10,15 @@ from typing import Any, Optional

 import torch

+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...fp8 import FP8GlobalStateManager
+from ...tensor import Quantizer
 from ..basic import AddExtraInput, BasicLinear, ConstantScale
 from ..op import (
     FusedOperation,
     FusibleOperation,
     OperationContext,
 )
-from ...tensor import Quantizer


 class ForwardLinearScaleAdd(FusedOperation):
...
...
@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation):
         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
...
...
@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
 from ...cpp_extensions import general_gemm
 from ...distributed import get_distributed_world_size
 from ...module.base import (
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
     fill_userbuffers_buffer_for_all_gather,
+    get_dummy_wgrad,
     get_ub,
     get_workspace,
-    _2X_ACC_DGRAD,
-    _2X_ACC_WGRAD,
 )
 from ...tensor.quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
...
...
@@ -240,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
         with_dgrad_all_gather_x = False
         with_wgrad_reduce_scatter_dx = False
         if tensor_parallel_mode == "row":
-            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
             ub_type_dgrad = CommOverlapType.AG
             with_dgrad_all_gather_dy = True
         elif tensor_parallel_mode == "column":
             if input_requires_grad and weight_requires_grad:
                 with_bulk_overlap = True
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.AG
                 with_dgrad_all_gather_x = True
-                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
+                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
                 ub_type_wgrad = CommOverlapType.RS
                 with_wgrad_reduce_scatter_dx = True
                 if ub_comm_wgrad.is_fp8_ubuf():
...
...
@@ -257,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                         "Userbuffers reduce-scatter is not supported with FP8 buffers"
                     )
             else:
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.RS
                 with_dgrad_reduce_scatter_dx = True
                 if ub_comm_dgrad.is_fp8_ubuf():
...
@@ -408,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream
,
dgrad_recv_stream
=
ub_comm_dgrad
.
get_communication_stream
()
ub_obj_overlap_wgrad
=
get_ub
(
ub_comm_name
+
"_wgrad"
)
ub_obj_overlap_wgrad
=
get_ub
(
ub_comm_name
+
"_wgrad"
,
with_quantized_compute
)
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
         # Note: Get grad tensor from param so we can accumulate
         # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
...
...
@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)

-        # Return gradients
-        grad_params = [() for _ in range(len(self.basic_ops))]
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
+        # Return gradients
+        grad_params = [() for _ in range(len(self.basic_ops))]
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
...
...
@@ -12,6 +12,7 @@ import torch
 from transformer_engine_torch import CommOverlapType

 from ...cpp_extensions import general_gemm
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import get_distributed_world_size
 from ...fp8 import FP8GlobalStateManager
 from ...module.base import (
...
...
@@ -189,7 +190,7 @@ class UserbuffersForwardLinear(FusedOperation):
             output_quantizer = None

         # Get Userbuffers communicator
-        ub_comm = get_ub(ub_comm_name + "_fprop")
+        ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
         with_ub_all_gather = tensor_parallel_mode == "column"
         with_ub_reduce_scatter = tensor_parallel_mode == "row"
         ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
...
...
@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation):
         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
...