OpenDAS / TransformerEngine / Commits

Commit 27ddce40, authored Oct 11, 2025 by wenjh

Merge branch 'nv_main'

Parents: d262ef4c, 5b3092a0
Changes: 208

Showing 20 changed files with 619 additions and 169 deletions (+619, -169)
transformer_engine/pytorch/module/layernorm_mlp.py                        +54  -37
transformer_engine/pytorch/module/linear.py                               +29  -11
transformer_engine/pytorch/onnx_extensions.py                             +51  -11
transformer_engine/pytorch/ops/_common.py                                  +3   -1
transformer_engine/pytorch/ops/basic/__init__.py                           +1   -1
transformer_engine/pytorch/ops/basic/activation.py                       +142  -13
transformer_engine/pytorch/ops/basic/basic_linear.py                      +33  -11
transformer_engine/pytorch/ops/basic/dropout.py                           +55  -17
transformer_engine/pytorch/ops/basic/l2normalization.py                    +6   -3
transformer_engine/pytorch/ops/basic/layer_norm.py                         +5   -2
transformer_engine/pytorch/ops/basic/rmsnorm.py                            +5   -2
transformer_engine/pytorch/ops/fused/__init__.py                           +4   -0
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py (new)       +133   -0
transformer_engine/pytorch/ops/fused/backward_linear_add.py               +24  -14
transformer_engine/pytorch/ops/fused/backward_linear_scale.py             +24  -14
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py     +6   -7
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py            +7   -8
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py           +4   -1
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py       +29  -15
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py         +4   -1
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -94,39 +94,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         # bf16 (recipe is None):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
-            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, None),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     if recipe.delayed() or recipe.mxfp8():
         # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
         # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
         return {
             "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
-            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     # no activation fusion written yet
     # Per-tensor current scaling or fp8 blockwise scaling: []
     if recipe.float8_current_scaling() or recipe.float8_block_scaling():
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
-            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
-            "reglu": (tex.reglu, tex.dreglu, None),
-            "swiglu": (tex.swiglu, tex.dswiglu, None),
             "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "relu": (tex.relu, tex.drelu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
             "srelu": (tex.srelu, tex.dsrelu, None),
+            "sreglu": (tex.sreglu, tex.dsreglu, None),
+            "silu": (tex.silu, tex.dsilu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     raise NotImplementedError(f"Unhandled recipe type {recipe}")
@@ -308,7 +314,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
         if ub_overlap_ag:
             # Copy into Userbuffers buffer
-            ub_obj_lnout = get_ub("fc1_fprop")
+            ub_obj_lnout = get_ub("fc1_fprop", fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_lnout,
                 ln_out,
@@ -446,20 +452,25 @@ class _LayerNormMLP(torch.autograd.Function):
                 act_out = activation_func(fc1_out, None)
                 act_out = tex.quantize(act_out, fc2_input_quantizer)
             else:
-                act_out = activation_func(fc1_out, fc2_input_quantizer)
+                if fp8_calibration:
+                    act_out = activation_func(fc1_out, None)
+                else:
+                    act_out = activation_func(fc1_out, fc2_input_quantizer)

         if not is_grad_enabled:
             clear_tensor_data(fc1_out)

-        if fp8_calibration:
-            fc2_input_quantizer.calibrate(act_out)
-            fc2_weight_quantizer.calibrate(fc2_weight)
+        if not fp8 and fp8_calibration:
+            if fc2_input_quantizer is not None:
+                fc2_input_quantizer.calibrate(act_out)
+            if fc2_weight_quantizer is not None:
+                fc2_weight_quantizer.calibrate(fc2_weight)

         # Configure Userbuffers reduce-scatter if needed
         ub_obj_fc2out = None
         reduce_scatter_out = None
         if ub_overlap_rs:
-            ub_obj_fc2out = get_ub("fc2_fprop")
+            ub_obj_fc2out = get_ub("fc2_fprop", fp8)
             dim_size = list(act_out.size())
             dim_size[0] //= tp_world_size
             dim_size[-1] = fc2_weight.size(0)
@@ -741,7 +752,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Note: Cast to expected dtype and perform tensor-parallel communication
         ub_obj_fc2_dgrad = None
         if ctx.ub_overlap_ag:
-            ub_obj_fc2_dgrad = get_ub("fc2_dgrad")
+            ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8)
             ctx.ub_obj_gradout = ub_obj_fc2_dgrad
         (
             grad_output,
@@ -765,7 +776,7 @@ class _LayerNormMLP(torch.autograd.Function):
             # wgrad GEMM requires input with column-wise usage
             quantizer.set_usage(rowwise=False, columnwise=True)
         if ctx.ub_bulk_dgrad:
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                 ub_obj_fc1_dgrad,
                 ln_out,
@@ -870,7 +881,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 ub_obj_fc2_dgrad.get_communication_stream()
             )
-            ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
+            ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8)
             ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -1045,16 +1056,16 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
         if ctx.ub_overlap_rs_dgrad:
             # Overlap DGRAD+RS
-            ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+            ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
             ub_type_fc1_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap ln_out all-gather with DGRAD compute
-                ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
+                ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8)
                 ub_type_fc1_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap FC1 DGRAD reduce-scatter with WGRAD compute
-                ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
+                ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8)
                 ub_type_fc1_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
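A pattern runs through all of these hunks: every get_ub call now takes an FP8 flag, so Userbuffers communicators are keyed by a (name, fp8) pair rather than by name alone. A minimal sketch of the registry shape this implies (hypothetical model for illustration; the real _ub_communicators structure lives in transformer_engine/pytorch/module/base.py):

_ub_communicators: dict = {}

def get_ub(name: str, use_fp8: bool):
    # Separate communicators per workspace name and FP8 mode,
    # matching the ("fc1_fprop", use_fp8) membership tests seen later
    # in this commit.
    key = (name, use_fp8)
    if key not in _ub_communicators:
        raise KeyError(f"No Userbuffers communicator registered for {key}")
    return _ub_communicators[key]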
@@ -1402,7 +1413,7 @@ class _LayerNormMLP(torch.autograd.Function):
 class LayerNormMLP(TransformerEngineBaseModule):
     r"""
     Applies layer normalization on the input followed by the MLP module, consisting of
-    2 successive linear transformations, separated by the GeLU activation.
+    2 successive linear transformations, separated by the activation function.

     Parameters
     ----------
@@ -1418,7 +1429,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                   type of normalization applied.
     activation : str, default = 'gelu'
                  activation function used.
-                 Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'.
+                 Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
+                 'silu', and 'swiglu'.
     init_method : Callable, default = `None`
                   used for initializing FC1 weights in the following way: `init_method(weight)`.
                   When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
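The expanded option list can be exercised directly; a minimal sketch (sizes are arbitrary and this assumes a CUDA build of Transformer Engine):

import torch
import transformer_engine.pytorch as te

# 'silu' and 'sreglu' are among the newly documented options; gated
# variants ('*glu') double FC1's output features internally.
layer = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="silu",
    params_dtype=torch.bfloat16,
).cuda()
x = torch.randn(32, 4, 1024, device="cuda", dtype=torch.bfloat16)
y = layer(x)
assert y.shape == x.shape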
@@ -1559,7 +1571,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.gemm_gelu_fusion = (
             bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
             and self.activation == "gelu"
-            and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm()))
+            and all(
+                ("fc1_fprop", use_fp8) not in _ub_communicators
+                or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm()
+                for use_fp8 in [False, True]
+            )
         )
         self.name = name
@@ -1619,7 +1635,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias = None

         # FC1 init
-        if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]:
+        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
@@ -1777,7 +1793,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fp8_output = False
         if self.ub_overlap_rs:
-            if get_ub("fc2_fprop").is_fp8_ubuf():
+            if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True
         with torch.cuda.device(
@@ -1915,7 +1931,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc2_grad_output_quantizer,
         ) = [None] * 10
         fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
-        if self.fp8:
+        if self.fp8 or self.fp8_calibration:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             fc1_input_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
@@ -2001,14 +2017,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
         activation_map = {
             "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
-            "relu": torch.nn.functional.relu,
             "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
-            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
-            "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
-            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
             "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
             * x.chunk(2, -1)[1],
-            "srelu": torch.nn.functional.softplus,
+            "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+            "relu": torch.nn.functional.relu,
+            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "srelu": lambda x: torch.nn.functional.relu(x) ** 2,
+            "sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 * x.chunk(2, -1)[1],
+            "silu": torch.nn.functional.silu,
+            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
         }
         if self.activation not in activation_map:
             raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
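A quick check of the gating convention used in this export map: the gate is computed from the FIRST chunk, whereas PyTorch's F.glu gates on the second (a standalone sketch with random data):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
a, b = x.chunk(2, -1)
te_style = F.silu(a) * b           # the "swiglu" entry above
torch_glu = a * torch.sigmoid(b)   # F.glu's convention
assert torch.allclose(F.glu(x), torch_glu)
assert not torch.allclose(te_style, torch_glu)  # opposite halves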
@@ -2129,7 +2148,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None, None]
         fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         fc1_weight_quantizer.internal = True
@@ -2182,10 +2201,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             if self.fc1_bias.grad is None:
                 self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
         if not self.fuse_wgrad_accumulation:
-            if self.fc2_weight.grad is None:
-                self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
-            if self.fc1_weight.grad is None:
-                self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
+            self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
+            self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
         del fc2_bias_grad_
         del fc2_wgrad
         del fc1_wgrad
transformer_engine/pytorch/module/linear.py

@@ -147,10 +147,10 @@ class _Linear(torch.autograd.Function):
         ub_obj = None
         ub_type = None
         if ub_overlap_rs_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.RS
         elif ub_overlap_ag_fprop:
-            ub_obj = get_ub(ub_name + "_fprop")
+            ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
         # ------------------------------------------------------
@@ -319,6 +319,13 @@ class _Linear(torch.autograd.Function):
         # Finished forward GEMM...
         # ------------------------------------------------------

+        # Deallocate GEMM input tensor if no longer needed
+        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
+        # deallocated by GC. Manually deallocating is a temporary hack.
+        if with_input_all_gather_nccl:
+            clear_tensor_data(inputmat_total)
+            inputmat_total = None
+
         # ------------------------------------------------------
         # Prepare output tensor
         # Note: Perform tensor-parallel communication
@@ -544,23 +551,23 @@ class _Linear(torch.autograd.Function):
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
             # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.AG
         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
             ub_obj_dgrad = ctx.ub_obj_gradout
             ub_type_dgrad = tex.CommOverlapType.RS
         else:
             if ctx.ub_bulk_dgrad:
                 # Overlap inputmat all-gather with dgrad compute
-                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
                 ub_obj_dgrad = ctx.ub_obj_gradout
                 ub_type_dgrad = tex.CommOverlapType.AG
             if ctx.ub_bulk_wgrad:
                 # Overlap dgrad reduce-scatter with wgrad compute
-                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
+                ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
                 ub_type_wgrad = tex.CommOverlapType.RS
         # --------------------------------------------------
@@ -793,7 +800,7 @@ class _Linear(torch.autograd.Function):
             dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
             # This object is separate from the ub_obj_wgrad object which is passed to the GEMM
-            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
             ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -905,9 +912,16 @@ class _Linear(torch.autograd.Function):
             grad_bias = grad_bias_
             del grad_bias_

-        # Deallocate input tensor if permitted
+        # Deallocate tensors if permitted
         if ctx.owns_input:
+            # Input tensor is internal
             clear_tensor_data(inputmat_total)
+        elif ctx.backward_input_needs_gather:
+            # Gathered input tensor is internal
+            clear_tensor_data(inputmat_total)
+        if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+            # Gathered grad output tensor is internal
+            clear_tensor_data(grad_output)

         # Update grad input if overlapping reduce-scatter with wgrad GEMM
         if ctx.ub_bulk_wgrad:
@@ -1404,10 +1418,14 @@ class Linear(TransformerEngineBaseModule):
             is_first_microbatch = False
         if self.ub_overlap_rs_fprop:
-            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_output = True
         if self.ub_overlap_rs_dgrad:
-            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+            if get_ub(
+                self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
+            ).is_fp8_ubuf():
                 fp8_grad = True
         with torch.cuda.device(
@@ -1666,7 +1684,7 @@ class Linear(TransformerEngineBaseModule):
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None]
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
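The recurring `self.fp8 or self.fp8_calibration` changes in both modules make quantizers available during calibration-only runs. A sketch of how such a run is typically driven (using the standard fp8_autocast API; sizes are arbitrary):

import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda")

# Calibration pass: run in higher precision while recording amax
# statistics, so FP8 scales are ready before enabling FP8 compute.
with te.fp8_autocast(enabled=False, calibrating=True):
    y = linear(x)

# Later: actual FP8 execution using the calibrated scales.
with te.fp8_autocast(enabled=True):
    y = linear(x)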
transformer_engine/pytorch/onnx_extensions.py

@@ -112,7 +112,9 @@ schema = defs.OpSchema(
     doc="TRT FP8 Quantize Linear used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for quantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
 )
@@ -126,11 +128,10 @@ TRT_FP8QuantizeLinear = onnxscript.values.Op(
 @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
-def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
+def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
     """Dequantize from Float8Tensor used for inference."""
-    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
     quantizer = Float8Quantizer(
-        scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
+        1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
     )
     quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
     return quantizer_tensor.dequantize()
@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor:
 def onnx_dequantize_fp8_symbolic(
-    tensor: onnxscript.onnx_types.TensorType, scale: float
+    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
 ) -> onnxscript.onnx_types.TensorType:
     """Symbolic dequantize from Float8Tensor used for inference."""
-    scale_inv = op.Constant(value_float=1 / scale)
     return TRT_FP8DequantizeLinear(tensor, scale_inv)
@@ -157,7 +157,9 @@ schema = defs.OpSchema(
     doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
 )
@@ -166,6 +168,43 @@ TRT_FP8DequantizeLinear = onnxscript.values.Op(
     opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
 )

+# ONNX FP8 Current Scaling Quantization
+@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
+def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
+    if tensor.dtype != torch.float32:
+        tensor = tensor.to(torch.float32)
+    amax = tensor.abs().max()
+    eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
+    amax = torch.maximum(amax, eps)
+    fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
+    scale = fp8_max / amax
+    q = torch.ops.tex.fp8_quantize(tensor, scale)
+    scale_inv = 1 / scale
+    return q, scale_inv
+
+
+@onnx_cs_quantize_fp8_op.register_fake
+def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
+        1, dtype=torch.float32, device=tensor.device
+    )
+
+
+def onnx_quantize_fp8_cs_symbolic(
+    tensor: onnxscript.onnx_types.TensorType,
+):
+    """Symbolic quantize with current scaling; computes scale_inv from tensor."""
+    # scale_inv = 1 / max(abs(tensor))
+    amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
+    eps = op.Constant(value_float=1.0e-12)
+    amax = op.Max(amax, eps)
+    scale_inv = op.Div(amax, op.Constant(value_float=448.0))
+    q = TRT_FP8QuantizeLinear(tensor, scale_inv)
+    return q, scale_inv
+
+
 # ONNX MXFP8 Quantization
@@ -194,12 +233,12 @@ def onnx_quantize_mxfp8_symbolic(
     tensor: onnxscript.onnx_types.TensorType,
 ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
     """Symbolic quantize to MXFP8Tensor used for inference."""
-    tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
+    tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor)
     return tensor_out, scale_inv_out

 schema = defs.OpSchema(
-    name="TRT_MXFP8QuantizeLinear",
+    name="TRT_MXFP8DynamicQuantize",
     domain="trt",
     since_version=1,
     doc="TRT MXFP8 Quantize Linear used for inference.",
@@ -214,8 +253,8 @@ schema = defs.OpSchema(
     ],
 )
-TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
-    opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
+TRT_MXFP8DynamicQuantize = onnxscript.values.Op(
+    opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema
 )
@@ -356,6 +395,7 @@ te_translation_table = {
     torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
     torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
     torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
+    torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
     torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
     torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
     torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
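The current-scaling math added above is easy to verify numerically: scale = 448 / amax (448 is the E4M3 max-normal), and dequantization multiplies by scale_inv. A standalone sketch using PyTorch's native float8 dtype instead of the tex kernel (requires a PyTorch build with float8_e4m3fn):

import torch

x = torch.randn(64, dtype=torch.float32)
amax = torch.maximum(x.abs().max(), torch.tensor(1e-12))
scale = 448.0 / amax
scale_inv = 1.0 / scale
q = (x * scale).to(torch.float8_e4m3fn)      # quantize with current scale
x_hat = q.to(torch.float32) * scale_inv      # dequantize
print((x - x_hat).abs().max())               # small quantization error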
transformer_engine/pytorch/ops/_common.py

@@ -29,7 +29,9 @@ def maybe_dequantize(
     if is_quantized_tensor(tensor):
         return tensor.dequantize(dtype=dtype)
     if dtype is not None and tensor.dtype != dtype:
-        return tensor.to(dtype)
+        tensor = tensor.to(dtype)
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
     return tensor
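With this change maybe_dequantize now also guarantees a contiguous result for plain tensors, not just a dtype cast. A minimal illustration of the new behavior (plain torch, no TE types):

import torch

t = torch.randn(4, 8).t()                   # non-contiguous view
u = t.to(torch.bfloat16).contiguous()       # what the new path produces
assert u.is_contiguous() and u.dtype == torch.bfloat16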
transformer_engine/pytorch/ops/basic/__init__.py

@@ -4,7 +4,7 @@
 """Single tensor operations supported by the operation fuser."""

-from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
+from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
 from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
transformer_engine/pytorch/ops/basic/activation.py

@@ -11,11 +11,25 @@ from typing import Optional
 import torch
 import transformer_engine_torch as tex

+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
 from ...utils import clear_tensor_data
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_dequantize

+__all__ = [
+    "GELU",
+    "GEGLU",
+    "QGELU",
+    "QGEGLU",
+    "ReLU",
+    "ReGLU",
+    "SReLU",
+    "SReGLU",
+    "SiLU",
+    "SwiGLU",
+]

 class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
     r"""Apply activation function
@@ -97,6 +111,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x)
             ctx.save_for_backward(x)
             ctx.dtype = dtype
             ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
@@ -147,37 +163,75 @@ class GELU(_ActivationOperation):
         return tex.dgelu(*args, **kwargs)

-class ReLU(_ActivationOperation):
-    r"""Rectified linear unit
+class GEGLU(_ActivationOperation):
+    r"""Gaussian Error Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:

     .. math::

-       \text{ReLU}(x) = \max(x,0)
+       \text{GEGLU}(a,b) = \text{GELU}(a) * b
+
+    where
+
+    .. math::
+
+       \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+
+    See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__.

     """

     def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.relu(*args, **kwargs)
+        return tex.geglu(*args, **kwargs)

     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.drelu(*args, **kwargs)
+        return tex.dgeglu(*args, **kwargs)

-class GEGLU(_ActivationOperation):
-    r"""Gaussian error gated linear unit
+class QGELU(_ActivationOperation):
+    r"""Quick Gaussian Error Linear Unit
+
+    Quick GELU from `HuggingFace<https://github.com/huggingface/transformers/blob/3e93dd295b5343557a83bc07b0b2ea64c926f9b4/src/transformers/activations.py#L90>`__
+    and `paper<https://github.com/hendrycks/GELUs>`__.
+
+    .. math::
+
+       \text{QGELU}(x) \approx x * \sigma(1.702 * x)
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.qgelu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dqgelu(*args, **kwargs)
+
+class QGEGLU(_ActivationOperation):
+    r"""Quick Gaussian Error Gated Linear Unit

     The input tensor is split into chunks :math:`a` and :math:`b`
     along the last dimension and the following is computed:

     .. math::

-       \text{GEGLU}(a,b) = \text{GELU}(a) * b
+       \text{QGEGLU}(a,b) = \text{QGELU}(a) * b

     where

     .. math::

-       \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right)
+       \text{QGELU}(x) \approx x * \sigma(1.702 * x)

     .. warning::
@@ -187,19 +241,33 @@ class GEGLU(_ActivationOperation):
        the first half of the input tensor, while PyTorch applies it to
        the second half.

+    See `GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__.
+
     """

+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.qgeglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dqgeglu(*args, **kwargs)
+
+class ReLU(_ActivationOperation):
+    r"""Rectified Linear Unit
+
+    .. math::
+
+       \text{ReLU}(x) = \max(x,0)
+
+    """
+
     def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.geglu(*args, **kwargs)
+        return tex.relu(*args, **kwargs)

     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.dgeglu(*args, **kwargs)
+        return tex.drelu(*args, **kwargs)

 class ReGLU(_ActivationOperation):
-    r"""Rectified gated linear unit
+    r"""Rectified Gated Linear Unit

     The input tensor is split into chunks :math:`a` and :math:`b`
     along the last dimension and the following is computed:
@@ -227,6 +295,67 @@ class ReGLU(_ActivationOperation):
         return tex.dreglu(*args, **kwargs)

+class SReLU(_ActivationOperation):
+    r"""Squared Rectified Linear Unit
+
+    .. math::
+
+       \text{SReLU}(x) = \max(x^2,0)
+
+    See `Primer: Searching for Efficient Transformers for Language Modeling<https://arxiv.org/abs/2109.08668v2>`__.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.srelu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsrelu(*args, **kwargs)
+
+class SReGLU(_ActivationOperation):
+    r"""Squared Rectified Gated Linear Unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{SReGLU}(a,b) = \max(a^2,0) * b
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.sreglu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsreglu(*args, **kwargs)
+
+class SiLU(_ActivationOperation):
+    r"""Sigmoid Linear Unit
+
+    .. math::
+
+       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
+
+    """
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.silu(*args, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.dsilu(*args, **kwargs)
+
 class SwiGLU(_ActivationOperation):
     r"""Swish gated linear unit
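The new op classes slot into the fusible-ops API like the existing ones. A sketch, assuming Sequential and BasicLinear are re-exported from transformer_engine.pytorch.ops (module layout inferred from the imports in this commit):

import torch
from transformer_engine.pytorch import ops

# SwiGLU is a gated op, so it halves the last dimension: 4096 -> 2048.
mlp = ops.Sequential(
    ops.BasicLinear(1024, 4096),
    ops.SwiGLU(),
    ops.BasicLinear(2048, 1024),
)
y = mlp(torch.randn(16, 1024, device="cuda"))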
transformer_engine/pytorch/ops/basic/basic_linear.py

@@ -12,26 +12,32 @@ from typing import Any, Optional
 import torch

-from transformer_engine.pytorch.module.base import get_workspace
 from ...cpp_extensions import general_gemm
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import (
     CudaRNGStatesTracker,
     gather_along_first_dim,
     reduce_scatter_along_first_dim,
 )
 from ...fp8 import FP8GlobalStateManager, Recipe
-from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
+from ...module.base import (
+    _2X_ACC_FPROP,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+    get_dummy_wgrad,
+    get_workspace,
+)
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ..op import BasicOperation, OperationContext
-from .._common import maybe_dequantize, is_quantized_tensor
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
     clear_tensor_data,
     devices_match,
 )
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize, is_quantized_tensor

 def _wait_async(handle: Optional[Any]) -> None:
@@ -73,7 +79,8 @@ class BasicLinear(BasicOperation):
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
-        meaningful.
+        meaningful. This is primarily intented to integrate with
+        Megatron-LM.
     userbuffers_options, dict, optional
         Options for overlapping tensor-parallel communication with
         compute using Userbuffers. This feature is highly
@@ -958,6 +965,8 @@ class BasicLinear(BasicOperation):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             ctx.save_for_backward(x_local, w)
             ctx.with_quantized_compute = with_quantized_compute
             ctx.input_quantizer = input_quantizer
@@ -979,20 +988,22 @@ class BasicLinear(BasicOperation):
         # Saved tensors from forward pass
         (x_local, w) = ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = self._accumulate_into_main_grad
         grad_weight = None
         if ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(self.weight, "__fsdp_param__"):
-                self.weight.main_grad = self.weight.get_main_grad()
+            weight_param = self.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()

-            if not hasattr(self.weight, "main_grad"):
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = self.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -1019,6 +1030,17 @@ class BasicLinear(BasicOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
         if accumulate_into_main_grad:
             grad_weight = None
+            weight_param = self.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )

         return grad_input, [grad_weight]
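A sketch of the wgrad-fusion contract these two hunks implement: the caller pre-allocates weight.main_grad, the backward pass accumulates into it directly, and weight.grad only receives a dummy tensor. Names come from the diff above; the exact constructor signature may differ by TE version:

import torch
from transformer_engine.pytorch.ops import BasicLinear

op = BasicLinear(1024, 1024, accumulate_into_main_grad=True)
op.weight.main_grad = torch.zeros_like(op.weight, dtype=torch.float32)
x = torch.randn(16, 1024, device="cuda", requires_grad=True)
y = op(x)
y.sum().backward()
# wgrad lands in op.weight.main_grad; op.weight.grad holds a dummy
# wgrad tensor (and grad_added_to_main_grad is set if the param has it).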
transformer_engine/pytorch/ops/basic/dropout.py

@@ -8,12 +8,12 @@ from __future__ import annotations
 from typing import Optional

 import torch
+import transformer_engine_torch as tex

-from transformer_engine.pytorch.ops.op import (
-    BasicOperation,
-    OperationContext,
-)
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor import Quantizer
+from ...tensor._internal.float8_tensor_base import Float8TensorBase
+from .._common import maybe_autocast_dtype, maybe_dequantize
+from ..op import BasicOperation, OperationContext

 class Dropout(BasicOperation):
@@ -27,7 +27,7 @@ class Dropout(BasicOperation):
     def __init__(self, p: float) -> None:
         super().__init__()
-        self.dropout_probability = p
+        self.dropout_probability: float = p

     def op_forward(
         self,
@@ -37,21 +37,46 @@ class Dropout(BasicOperation):
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

-        # Compute dropout if training
-        out = input_
-        is_training = self.training
-        mask = None
-        if is_training:
+        # Output dtype
+        dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
+
+        # Choose implementation
+        impl = None
+        if not self.training:
+            impl = "evaluation"
+        elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16):
+            impl = "fused"
+        else:
+            impl = "unfused"
+
+        # Perform dropout
+        out: torch.Tensor
+        mask: Optional[torch.Tensor] = None
+        if impl == "evaluation":
+            out = input_
+        elif impl == "fused":
+            x = input_
+            if not isinstance(x, Float8TensorBase):
+                x = maybe_dequantize(x, dtype=dtype)
+            out, mask = tex.dropout_fwd(x, self.dropout_probability)
+        elif impl == "unfused":
+            x = maybe_dequantize(input_, dtype=dtype)
             keep_prob = 1 - self.dropout_probability
-            mask = torch.empty_like(input_)
+            mask = torch.empty_like(x)
             mask.bernoulli_(keep_prob)
             mask *= 1 / keep_prob
-            out = out * mask
+            out = x * mask
+        else:
+            raise ValueError(f"Unsupported forward implementation {impl}")

         # Save context for backward
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(mask)
             ctx.save_for_backward(mask)
-            ctx.is_training = is_training
+            ctx.impl = impl
+            ctx.dropout_probability = self.dropout_probability
+            ctx.dtype = dtype

         return out
@@ -60,8 +85,21 @@ class Dropout(BasicOperation):
         ctx: OperationContext,
         grad_output: torch.Tensor,
     ) -> tuple[torch.Tensor, tuple[()]]:
+        # Saved tensors from forward pass
         (mask,) = ctx.saved_tensors
-        grad_input = grad_output
-        if ctx.is_training:
-            grad_input = grad_input * mask
+
+        # Perform dropout backward pass
+        grad_input: torch.Tensor
+        if ctx.impl == "evaluation":
+            grad_input = grad_output
+        elif ctx.impl == "fused":
+            dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
+            grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability)
+        elif ctx.impl == "unfused":
+            dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
+            grad_input = dy * mask
+        else:
+            raise ValueError(f"Unsupported backward implementation {ctx.impl}")
         return grad_input, ()
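The dispatch rule introduced here is worth restating standalone: the fused kernel is only taken for fp16/bf16 tensors whose element count is a multiple of 16. A sketch (ignoring the autocast-derived dtype for brevity):

import torch

def choose_dropout_impl(x: torch.Tensor, training: bool) -> str:
    # Mirrors Dropout.op_forward above.
    if not training:
        return "evaluation"
    if x.numel() % 16 == 0 and x.dtype in (torch.float16, torch.bfloat16):
        return "fused"
    return "unfused"

assert choose_dropout_impl(torch.empty(32, dtype=torch.bfloat16), True) == "fused"
assert choose_dropout_impl(torch.empty(33, dtype=torch.float32), True) == "unfused"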
transformer_engine/pytorch/ops/basic/l2normalization.py

@@ -10,10 +10,8 @@ import os
 import torch

-from ...utils import clear_tensor_data
 from ... import torch_version
-from .._common import maybe_dequantize
-from ..op import BasicOperation, OperationContext
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...jit import (
     l2normalization_fused,
     l2normalization_fwd_fused,
@@ -22,6 +20,9 @@ from ...jit import (
     warmup_jit_l2normalization_all_dtypes,
 )
 from ...tensor import Quantizer
+from ...utils import clear_tensor_data
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize

 class L2Normalization(BasicOperation):
@@ -101,6 +102,8 @@ class L2Normalization(BasicOperation):
         # Save state for backward pass
         if requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x, rsqrt_norm)
             ctx.save_for_backward(x, rsqrt_norm)

         return y
transformer_engine/pytorch/ops/basic/layer_norm.py

@@ -14,6 +14,9 @@ import torch
 from transformer_engine_torch import layernorm_bwd, layernorm_fwd

 from ...constants import TE_DType
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...export import is_in_onnx_export_mode
+from ...tensor import Quantizer
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
@@ -22,8 +25,6 @@ from ...utils import (
 )
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, maybe_dequantize
-from ...export import is_in_onnx_export_mode
-from ...tensor import Quantizer

 class LayerNorm(BasicOperation):
@@ -215,6 +216,8 @@ class LayerNorm(BasicOperation):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x, means, rstdevs)
             ctx.save_for_backward(x, means, rstdevs)
             ctx.dtype = dtype
transformer_engine/pytorch/ops/basic/rmsnorm.py

@@ -14,6 +14,9 @@ import torch
 from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd

 from ...constants import TE_DType
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...export import is_in_onnx_export_mode
+from ...tensor import Quantizer
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
@@ -22,8 +25,6 @@ from ...utils import (
 )
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, maybe_dequantize
-from ...export import is_in_onnx_export_mode
-from ...tensor import Quantizer

 class RMSNorm(BasicOperation):

@@ -196,6 +197,8 @@ class RMSNorm(BasicOperation):
         # Save state for backward pass
         if ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x, rstdevs)
             ctx.save_for_backward(x, rstdevs)
             ctx.dtype = dtype
transformer_engine/pytorch/ops/fused/__init__.py View file @ 27ddce40
...
@@ -8,6 +8,10 @@ from .backward_activation_bias import (
     BackwardActivationBias,
     fuse_backward_activation_bias,
 )
+from .backward_add_rmsnorm import (
+    BackwardAddRMSNorm,
+    fuse_backward_add_rmsnorm,
+)
 from .backward_linear_add import (
     BackwardLinearAdd,
     fuse_backward_linear_add,
...
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py 0 → 100644 View file @ 27ddce40
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused backward RMSNorm + add."""

from __future__ import annotations
from typing import Optional
import math

import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize


class BackwardAddRMSNorm(FusedOperation):
    """Fused backward RMSNorm + add"""

    def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm):
        super().__init__((add, rmsnorm))

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:

        # Get basic operations
        rmsnorm_op = self.basic_ops[1]
        rmsnorm_op_ctx = basic_op_ctxs[0]

        # Saved tensors from forward pass
        x, rstdevs = rmsnorm_op_ctx.saved_tensors

        # Tensor dims
        weight_dims = rmsnorm_op.weight.size()
        inner_dim = math.prod(weight_dims)

        # Check input tensors
        dtype = rmsnorm_op_ctx.dtype
        extra_grad = basic_op_grad_extra_outputs[1][0]
        dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
        w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
        add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())

        # Compute RMSNorm backward pass
        dx, dw = tex.rmsnorm_bwd_add(
            dy,
            x,
            add,
            rstdevs,
            w,
            rmsnorm_op._sm_margins["backward"],
            rmsnorm_op.zero_centered_gamma,
        )

        # Clear saved tensors if possible
        clear_tensor_data(x)
        clear_tensor_data(rstdevs)

        # Reshape results
        grad_input = dx.view(grad_output.size())
        grad_weight = dw.view(weight_dims)

        return grad_input, [(grad_weight,), ()], [(), ()]


def fuse_backward_add_rmsnorm(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward RMSNorm + add

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Scan through ops, fusing if possible
    out = []
    window = []
    while len(ops) >= 2:
        out.extend(window)

        # Check if first op is RMSNorm
        window, ops = ops[:1], ops[1:]
        op, _ = window[0]
        if not isinstance(op, RMSNorm):
            continue

        # Check if second op is "make extra output"
        op, _ = ops[0]
        if not isinstance(op, MakeExtraOutput):
            continue
        if op._in_place:
            continue
        window.extend(ops[:1])
        ops = ops[1:]

        # Replace window with fused op
        op = BackwardAddRMSNorm(
            rmsnorm=window[0][0],
            add=window[1][0],
        )
        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
        window = [(op, basic_op_idxs)]

    # Return list of ops
    out.extend(window)
    out.extend(ops)
    return out
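The kernel call this fusion relies on, tex.rmsnorm_bwd_add, folds a residual gradient into the RMSNorm backward so that the separate elementwise add handled by MakeExtraOutput disappears. A reference sketch of the math being fused (reference math only, not the CUDA kernel; assumes zero_centered_gamma=False, eps already folded into rstdevs, and 2D inputs):

import torch

def rmsnorm_bwd_add_reference(dy, x, add, rstdevs, w):
    """Reference for tex.rmsnorm_bwd_add: RMSNorm backward with a fused
    residual-gradient add. Shapes: dy, x, add are (N, H); rstdevs is (N, 1);
    w is (H,)."""
    xhat = x * rstdevs                  # normalized input
    dw = (dy * xhat).sum(dim=0)         # weight gradient
    g = dy * w
    dx = rstdevs * (g - xhat * (g * xhat).mean(dim=-1, keepdim=True))
    return dx + add, dw                 # the "+ add" is what the fusion absorbs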
transformer_engine/pytorch/ops/fused/backward_linear_add.py View file @ 27ddce40
...
@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
 from ...utils import clear_tensor_data
+from ..basic import BasicLinear, MakeExtraOutput
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class BackwardLinearAdd(FusedOperation):
@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):

         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+            grad_weight = get_dummy_wgrad(
+                list(weight_param.size()),
+                weight_param.dtype,
+                zero=getattr(weight_param, "zero_out_wgrad", False),
+            )
+
         return grad_input, [(grad_weight,), ()], [(), ()]
...
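The dummy-wgrad handling above exists because autograd still expects a tensor in the weight's gradient slot even when the wgrad GEMM has accumulated directly into param.main_grad. A schematic of the pattern (get_dummy_wgrad is TE's helper; everything else here is a hypothetical standalone sketch):

import torch

def backward_wgrad(dy, x, weight_param):
    """Hypothetical sketch: dy is (N, out), x is (N, in)."""
    if hasattr(weight_param, "main_grad"):
        # Accumulate the wgrad GEMM directly into the preallocated buffer...
        weight_param.main_grad.addmm_(dy.t(), x)
        weight_param.grad_added_to_main_grad = True
        # ...and hand autograd a placeholder so shape/dtype checks still pass.
        return torch.empty_like(weight_param)
    return dy.t() @ x  # ordinary wgrad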
transformer_engine/pytorch/ops/fused/backward_linear_scale.py View file @ 27ddce40
...
@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from ..basic import BasicLinear, ConstantScale
-from ..op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
 from ...utils import clear_tensor_data
+from ..basic import BasicLinear, ConstantScale
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class BackwardLinearScale(FusedOperation):
@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):

         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+            grad_weight = get_dummy_wgrad(
+                list(weight_param.size()),
+                weight_param.dtype,
+                zero=getattr(weight_param, "zero_out_wgrad", False),
+            )
+
         return grad_input, [(), (grad_weight,)], [(), ()]
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py View file @ 27ddce40
...
@@ -10,14 +10,11 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...fp8 import FP8GlobalStateManager
 from ...tensor import Quantizer
+from ..basic import BasicLinear, Bias
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class ForwardLinearBiasActivation(FusedOperation):
@@ -121,6 +118,8 @@ class ForwardLinearBiasActivation(FusedOperation):

         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py View file @ 27ddce40
...
@@ -10,14 +10,11 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
-from transformer_engine.pytorch.tensor import Quantizer
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...fp8 import FP8GlobalStateManager
+from ...tensor import Quantizer
+from ..basic import AddExtraInput, BasicLinear, Bias
+from ..op import FusedOperation, FusibleOperation, OperationContext


 class ForwardLinearBiasAdd(FusedOperation):
@@ -118,6 +115,8 @@ class ForwardLinearBiasAdd(FusedOperation):

         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py View file @ 27ddce40
...
@@ -10,14 +10,15 @@ from typing import Any, Optional

 import torch

+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...fp8 import FP8GlobalStateManager
+from ...tensor import Quantizer
 from ..basic import AddExtraInput, BasicLinear, ConstantScale
 from ..op import (
     FusedOperation,
     FusibleOperation,
     OperationContext,
 )
-from ...tensor import Quantizer


 class ForwardLinearScaleAdd(FusedOperation):
@@ -95,6 +96,8 @@ class ForwardLinearScaleAdd(FusedOperation):

         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py View file @ 27ddce40
...
@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter

 from ...cpp_extensions import general_gemm
 from ...distributed import get_distributed_world_size
 from ...module.base import (
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
     fill_userbuffers_buffer_for_all_gather,
+    get_dummy_wgrad,
     get_ub,
     get_workspace,
-    _2X_ACC_DGRAD,
-    _2X_ACC_WGRAD,
 )
 from ...tensor.quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
@@ -240,16 +241,16 @@ class UserbuffersBackwardLinear(FusedOperation):
         with_dgrad_all_gather_x = False
         with_wgrad_reduce_scatter_dx = False
         if tensor_parallel_mode == "row":
-            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
             ub_type_dgrad = CommOverlapType.AG
             with_dgrad_all_gather_dy = True
         elif tensor_parallel_mode == "column":
             if input_requires_grad and weight_requires_grad:
                 with_bulk_overlap = True
-                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
                 ub_type_dgrad = CommOverlapType.AG
                 with_dgrad_all_gather_x = True
-                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
+                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
                 ub_type_wgrad = CommOverlapType.RS
                 with_wgrad_reduce_scatter_dx = True
                 if ub_comm_wgrad.is_fp8_ubuf():
@@ -257,7 +258,7 @@ class UserbuffersBackwardLinear(FusedOperation):
                     "Userbuffers reduce-scatter is not supported with FP8 buffers"
                 )
         else:
-            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
+            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute)
             ub_type_dgrad = CommOverlapType.RS
             with_dgrad_reduce_scatter_dx = True
             if ub_comm_dgrad.is_fp8_ubuf():
@@ -408,7 +409,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             # Get the communication stream from the dgrad GEMM to use for the AG
             dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()

-            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
+            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute)
             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):

         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):

         # Clear input tensor if possible
         clear_tensor_data(x_local)

-        # Return gradients
-        grad_params = [() for _ in range(len(self.basic_ops))]
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
         if accumulate_into_main_grad:
             grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+            grad_weight = get_dummy_wgrad(
+                list(weight_param.size()),
+                weight_param.dtype,
+                zero=getattr(weight_param, "zero_out_wgrad", False),
+            )
+
+        # Return gradients
+        grad_params = [() for _ in range(len(self.basic_ops))]
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
...
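Every get_ub call in this file now also passes with_quantized_compute, which suggests the Userbuffers registry keys its communicators by quantization mode as well as by name. A hypothetical sketch of that lookup convention (not TE's actual registry code):

_ub_communicators: dict[tuple[str, bool], object] = {}

def get_ub(name: str, quantized: bool) -> object:
    """Return the communicator registered for this name and quantization mode,
    so FP8 and non-FP8 passes never share incompatible buffers."""
    key = (name, quantized)
    if key not in _ub_communicators:
        raise KeyError(f"Userbuffers communicator {key} was never initialized")
    return _ub_communicators[key]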
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py View file @ 27ddce40
...
@@ -12,6 +12,7 @@ import torch

 from transformer_engine_torch import CommOverlapType
 from ...cpp_extensions import general_gemm
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import get_distributed_world_size
 from ...fp8 import FP8GlobalStateManager
 from ...module.base import (
@@ -189,7 +190,7 @@ class UserbuffersForwardLinear(FusedOperation):
             output_quantizer = None

         # Get Userbuffers communicator
-        ub_comm = get_ub(ub_comm_name + "_fprop")
+        ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute)
         with_ub_all_gather = tensor_parallel_mode == "column"
         with_ub_reduce_scatter = tensor_parallel_mode == "row"
         ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
@@ -353,6 +354,8 @@ class UserbuffersForwardLinear(FusedOperation):

         # Save state for backward pass
         if linear_op_ctx.requires_grad:
+            if is_cpu_offload_enabled():
+                mark_activation_offload(x_local)
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer
...
...
Prev
1
…
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment