OpenDAS / TransformerEngine · Commits

Commit 063ef88d
Authored Dec 03, 2025 by wenjh
Merge nv main up to v2.10.0.dev0
Signed-off-by: wenjh <wenjh@sugon.com>
Parents: 91670b05, 5624dbb4
Changes: 298
Showing 20 changed files with 387 additions and 157 deletions (+387, -157)
transformer_engine/pytorch/module/grouped_linear.py                       +38  -41
transformer_engine/pytorch/module/layernorm_linear.py                     +55  -18
transformer_engine/pytorch/module/layernorm_mlp.py                        +88  -25
transformer_engine/pytorch/module/linear.py                               +58  -19
transformer_engine/pytorch/ops/_common.py                                 +5   -5
transformer_engine/pytorch/ops/basic/__init__.py                          +13  -1
transformer_engine/pytorch/ops/basic/activation.py                        +36  -0
transformer_engine/pytorch/ops/basic/basic_linear.py                      +54  -29
transformer_engine/pytorch/ops/basic/dropout.py                           +2   -2
transformer_engine/pytorch/ops/basic/quantize.py                          +3   -3
transformer_engine/pytorch/ops/fused/backward_activation_bias.py          +1   -1
transformer_engine/pytorch/ops/fused/backward_linear_add.py               +1   -0
transformer_engine/pytorch/ops/fused/backward_linear_scale.py             +1   -0
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py    +2   -2
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py           +2   -2
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py          +2   -2
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py       +1   -0
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py        +3   -3
transformer_engine/pytorch/ops/fuser.py                                   +5   -1
transformer_engine/pytorch/ops/op.py                                      +17  -3
transformer_engine/pytorch/module/grouped_linear.py
...
@@ -22,7 +22,7 @@ from .base import (
     _2X_ACC_WGRAD,
 )
 from ._common import WeightGradStore
-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..utils import (
     divide,
     cast_if_needed,
...
@@ -46,7 +46,7 @@ from ..cpu_offload import is_cpu_offload_enabled
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.quantized_tensor import (
-    QuantizedTensorBase,
+    QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -204,39 +204,27 @@ class _GroupedLinear(torch.autograd.Function):
                 inputmats[0] = inp
             else:
                 for inputmat in inputmats:
-                    if isinstance(inputmat, QuantizedTensorBase):
+                    if isinstance(inputmat, QuantizedTensorStorage):
                         inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
         else:
             inputmats = [None] * num_gemms
         if inp.requires_grad:
             for weight in weights_fp8:
-                if isinstance(weight, QuantizedTensorBase):
+                if isinstance(weight, QuantizedTensorStorage):
                     weight.update_usage(columnwise_usage=True)
-            for i in range(num_gemms):
-                weights[i].offloading_activation = False
-                weights_fp8[i].offloading_activation = False
-                biases[i].offloading_activation = False
-            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-            if fine_grained_activation_offloading and cpu_offloading:
-                raise ValueError(
-                    f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
-                )
-            if (
-                fine_grained_activation_offloading
-                and weights[0].requires_grad
-                and fuse_wgrad_accumulation
-            ):
-                grad_added_to_main_grad_list = []
-                for weight in weights:
-                    if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
-                        grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
-                        weight.grad_added_to_main_grad = True
-                    else:
-                        grad_added_to_main_grad_list.append(None)
-                ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
+            if cpu_offloading:
+                ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
+                if ctx.grad_added_to_main_grad:
+                    # If you are passing torch.nn.Parameter through the Torch hooks, you will
+                    # get back torch.Tensor. Torch rips off the Parameter wrapper.
+                    # You need to preserve the weight object to have all the attributes user
+                    # sets for the weights. Because of this, it is not recommended to offload
+                    # weights if weights are externally touched outside this module
+                    ctx.weight_objects = []
+                    for weight in weights:
+                        ctx.weight_objects.append(weight)
         tensors_to_save, tensor_objects = prepare_for_saving(
             *inputmats,
...
@@ -300,13 +288,15 @@ class _GroupedLinear(torch.autograd.Function):
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
-            for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
-                w.main_grad = main_grads[i]
-                weights[i] = w
-                if ctx.fine_grained_activation_offloading and weights[0].requires_grad:
-                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
+        if ctx.cpu_offloading:
+            if ctx.grad_added_to_main_grad:
+                for i, weight in enumerate(ctx.weight_objects):
+                    origin_weights[i] = ctx.weight_objects[i]
+                    ctx.weight_objects[i] = None
+        if ctx.fuse_wgrad_accumulation:
+            for i in range(N):
+                origin_weights[i].main_grad = main_grads[i]
         # Preprocess grad output
         grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
...
@@ -369,7 +359,7 @@ class _GroupedLinear(torch.autograd.Function):
             )
             for weight, quantizer in zip(weights, ctx.weight_quantizers):
-                if quantizer is not None and isinstance(weight, QuantizedTensorBase):
+                if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
                     weight.update_usage(
                         rowwise_usage=quantizer.rowwise_usage,
                         columnwise_usage=quantizer.columnwise_usage,
...
@@ -433,7 +423,11 @@ class _GroupedLinear(torch.autograd.Function):
                     use_bias=ctx.use_bias if grad_biases[0] is None else None,
                     bias=biases,
                     use_split_accumulator=wgrad_gemm_use_split_accumulator,
-                    accumulate=accumulate_wgrad_into_param_main_grad,
+                    accumulate=(
+                        accumulate_wgrad_into_param_main_grad
+                        if not getattr(weights[0], "overwrite_main_grad", False)
+                        else False
+                    ),
                 )
             # WGRAD
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
...
@@ -555,7 +549,9 @@ class GroupedLinear(TransformerEngineBaseModule):
                  the weight gradient. When enabled, it is assumed that the weights
                  have an additional `main_grad` attribute (used instead of the
                  regular `grad`) which is a pre-allocated buffer of the correct
-                 size to accumulate gradients in.
+                 size to accumulate gradients in. This argument along with
+                 weight tensor having attribute 'overwrite_main_grad' set to True
+                 will overwrite `main_grad` instead of accumulating.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias itself, but
                  instead return the bias value during the forward pass together with the
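A note on usage: the new `overwrite_main_grad` behaviour is driven purely by an attribute on the weight tensor, so no new constructor argument is involved. A minimal sketch of how a training loop might opt in, assuming a Megatron-style pre-allocated `main_grad` buffer (layer sizes, dtypes and `m_splits` below are illustrative, not taken from the commit):

    import torch
    import transformer_engine.pytorch as te

    # GroupedLinear with fused wgrad accumulation: wgrad GEMMs write into weight.main_grad.
    layer = te.GroupedLinear(2, 1024, 1024, bias=False, fuse_wgrad_accumulation=True)

    for i in range(layer.num_gemms):
        weight = getattr(layer, f"weight{i}")
        weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)
        # With the flag set, the wgrad GEMM is launched with accumulate=False,
        # so main_grad is overwritten instead of accumulated into.
        weight.overwrite_main_grad = True

    inp = torch.randn(64, 1024, device="cuda", requires_grad=True)
    out = layer(inp, m_splits=[32, 32])
    out.sum().backward()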
...
@@ -772,7 +768,7 @@ class GroupedLinear(TransformerEngineBaseModule):
             produced)
         """
         assert not isinstance(
-            inp, QuantizedTensorBase
+            inp, QuantizedTensorStorage
         ), "GroupedLinear doesn't support input tensor in FP8."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
...
@@ -907,16 +903,17 @@ class GroupedLinear(TransformerEngineBaseModule):
                     self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
-    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
         """Get the weight tensors of the module."""
         weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
-        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
+        if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors):
             warnings.warn(
                 "You are using quantized weights without quantized compute. "
                 "Please make sure this is intentional."
             )
             weight_tensors = [
-                w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
+                w.dequantize() if isinstance(w, QuantizedTensorStorage) else w for w in weight_tensors
             ]
         return weight_tensors
...
transformer_engine/pytorch/module/layernorm_linear.py
...
@@ -16,6 +16,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.tensor.utils import is_experimental
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_workspace,
...
@@ -26,9 +27,10 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..utils import (
     assert_dim_for_fp8_exec,
+    assert_dim_for_all_gather,
     cast_if_needed,
     clear_tensor_data,
     divide,
...
@@ -57,7 +59,7 @@ from ..graph import is_graph_capturing
 from ._common import apply_normalization, noop_cat, WeightGradStore
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
-    QuantizedTensorBase,
+    QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -65,8 +67,8 @@ from ..tensor.quantized_tensor import (
 from ...debug.pytorch.debug_state import TEDebugState
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
+from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
...
@@ -144,6 +146,8 @@ class _LayerNormLinear(torch.autograd.Function):
         if ub_name is not None:
             nvtx_label = f"{nvtx_label}.{ub_name}"
+        with_input_all_gather = parallel_mode == "column" and sequence_parallel
+
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
...
@@ -153,6 +157,7 @@ class _LayerNormLinear(torch.autograd.Function):
         inputmat = inp
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
+            assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)
         # Cast for native AMP
         nvtx_range_push(f"{nvtx_label}.norm_input_cast")
...
@@ -166,7 +171,6 @@ class _LayerNormLinear(torch.autograd.Function):
         weight_requires_grad = weight.requires_grad
         backward_needs_input = is_grad_enabled and weight_requires_grad
-        with_input_all_gather = parallel_mode == "column" and sequence_parallel
         # Configure Userbuffers communication (comm+GEMM overlap)
         if debug:  # turn off userbuffers in debug mode
...
@@ -199,11 +203,13 @@ class _LayerNormLinear(torch.autograd.Function):
         # Avoid quantized norm kernel if norm output will be returned
         # or if a gather of ln_out must be in high precision.
+        experimental = is_experimental(input_quantizer)
         with_quantized_norm = (
             fp8
             and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
+            and not experimental  # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
         )
         # Apply normalization
...
@@ -249,7 +255,8 @@ class _LayerNormLinear(torch.autograd.Function):
                 quantizer = None
                 if fp8 or debug:
                     quantizer = input_quantizer
-                    if not with_quantized_norm:
+                    # experimental recipe doesn't need to support quantized AG
+                    if not with_quantized_norm and not experimental:
                         ln_out = quantizer(ln_out)
                         quantizer.set_usage(rowwise=True, columnwise=False)
                 if ub_overlap_ag_fprop:  # Initialize Userbuffers all-gather
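The two hunks above gate the quantized-norm fast path on the new `experimental` flag. A small, self-contained restatement of that predicate (names mirror the diff; the helper itself is illustrative):

    def should_quantize_norm_output(fp8, debug, return_ln_out, return_ln_out_gathered, experimental):
        # Quantized layernorm output is only produced when no high-precision copy
        # is needed and the active recipe is not an experimental one.
        return (
            fp8
            and not debug
            and not return_ln_out
            and not return_ln_out_gathered
            and not experimental
        )

    assert should_quantize_norm_output(True, False, False, False, False)
    assert not should_quantize_norm_output(True, False, False, False, True)  # experimental recipe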
...
@@ -280,7 +287,7 @@ class _LayerNormLinear(torch.autograd.Function):
         weightmat = weight
         quantized_weight = False
         if fp8 or debug:
-            quantized_weight = not isinstance(weight, QuantizedTensorBase)
+            quantized_weight = not isinstance(weight, QuantizedTensorStorage)
             # Configure quantizer
             if weight_quantizer is not None:
...
@@ -405,18 +412,18 @@ class _LayerNormLinear(torch.autograd.Function):
             # Input with column-wise usage is needed for wgrad GEMM.
             if backward_needs_input:
-                if isinstance(ln_out, QuantizedTensorBase):
+                if isinstance(ln_out, QuantizedTensorStorage):
                     # For sequence parallel in vanilla FP8, rowwise data is
                     # to gather the input. For MXFP8, columnwise only data
                     # can be allgathered.
                     if (
-                        isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
+                        isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage))
                         or not ctx.ln_out_needs_gather
                     ):
                         ln_out.update_usage(rowwise_usage=False)
             # Weight with column-wise usage is needed for dgrad GEMM.
-            if isinstance(weightmat, QuantizedTensorBase):
+            if isinstance(weightmat, QuantizedTensorStorage):
                 weightmat.update_usage(columnwise_usage=True)
             if cpu_offloading:
...
@@ -716,9 +723,9 @@ class _LayerNormLinear(torch.autograd.Function):
         # --------------------------------------------------
         # Make sure required data is available
-        if isinstance(grad_output, QuantizedTensorBase):
+        if isinstance(grad_output, QuantizedTensorStorage):
             grad_output.update_usage(rowwise_usage=True)
-        if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
+        if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage):
             weight.update_usage(columnwise_usage=True)
         # Choose whether to use GEMM kernel with split accumulator
...
@@ -837,14 +844,14 @@ class _LayerNormLinear(torch.autograd.Function):
                     ln_out_total_work.wait()
                     ln_out_total_work = None
                 if ctx.fp8 or ctx.debug:
-                    if isinstance(ln_out_total, QuantizedTensorBase):
+                    if isinstance(ln_out_total, QuantizedTensorStorage):
                         ln_out_total.update_usage(columnwise_usage=True)
                     else:
                         ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                         ln_out_total = ctx.input_quantizer(ln_out_total)
             if ctx.fp8 or ctx.debug:
-                if isinstance(grad_output, QuantizedTensorBase):
+                if isinstance(grad_output, QuantizedTensorStorage):
                     grad_output.update_usage(columnwise_usage=True)
                 else:
                     ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
@@ -880,7 +887,11 @@ class _LayerNormLinear(torch.autograd.Function):
                         main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                     ),
                     "quantization_params": ctx.grad_weight_quantizer,
-                    "accumulate": accumulate_wgrad_into_param_main_grad,
+                    "accumulate": (
+                        accumulate_wgrad_into_param_main_grad
+                        if not getattr(weight, "overwrite_main_grad", False)
+                        else False
+                    ),
                     "layout": "NT",
                     "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                     "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
...
@@ -1037,7 +1048,7 @@ class _LayerNormLinear(torch.autograd.Function):
         nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
         # Scatter fp8 weight buffers
-        # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
+        # if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
         #     _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
         return (
...
@@ -1164,7 +1175,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
                  the weight gradient. When enabled, it is assumed that the weights
                  have an additional `main_grad` attribute (used instead of the
                  regular `grad`) which is a pre-allocated buffer of the correct
-                 size to accumulate gradients in.
+                 size to accumulate gradients in. This argument along with
+                 weight tensor having attribute 'overwrite_main_grad' set to True
+                 will overwrite `main_grad` instead of accumulating.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias itself, but
                  instead return the bias value during the forward pass together with the
...
@@ -1470,6 +1483,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
         elif recipe.float8_block_scaling():
             self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
+        elif recipe.nvfp4():
+            self._customize_quantizers_nvfp4(fwd, recipe)
         # elif other recipes (mxfp8, etc)
     def reset_layer_norm_parameters(self) -> None:
...
@@ -1812,7 +1827,29 @@ class LayerNormLinear(TransformerEngineBaseModule):
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_reduction_group = self.tp_group
-    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + layernorm_linear."""
+        assert recipe.nvfp4(), "Incorrect recipe."
+        if fwd:
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # set input_quantizer with amax reduction TP group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+        else:
+            if self.sequence_parallel and self.parallel_mode == "row":
+                # customize grad_output_quantizer with amax reduction TP group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_group = self.tp_group
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
         """Get the weight tensors of the module."""
         unfused_weights = [getattr(self, name) for name in self.weight_names]
         if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
...
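The `_customize_quantizers_nvfp4` hook added above only flips `with_amax_reduction` and points the quantizer at the tensor-parallel group. A toy illustration of the reduction it requests, assuming an already-initialized torch.distributed process group (the helper below is illustrative and not part of TransformerEngine):

    import torch
    import torch.distributed as dist

    def tp_group_amax(local_shard: torch.Tensor, tp_group) -> torch.Tensor:
        # Under sequence parallelism each rank only sees a shard of the activation,
        # so a per-tensor scale has to come from the max over the whole TP group.
        amax = local_shard.abs().amax()
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
        return amax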
transformer_engine/pytorch/module/layernorm_mlp.py
...
@@ -18,6 +18,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.tensor.utils import is_experimental
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_workspace,
...
@@ -28,7 +29,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..jit import (
     bias_gelu_fused,
     bgrad_dgelu_fused,
...
@@ -41,6 +42,7 @@ from ..utils import (
     init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
+    assert_dim_for_all_gather,
     clear_tensor_data,
     requires_grad,
     needs_quantized_gemm,
...
@@ -65,11 +67,12 @@ from ..tensor.float8_tensor import (
     Float8Tensor,
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ._common import apply_normalization, WeightGradStore
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
-    QuantizedTensorBase,
+    QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -120,8 +123,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
             "swiglu": (tex.swiglu, tex.dswiglu, None),
         }
     # no activation fusion written yet
-    # Per-tensor current scaling or fp8 blockwise scaling: []
-    if recipe.float8_current_scaling() or recipe.float8_block_scaling():
+    # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
+    # TODO(ksivaman): Fuse nvfp4 act once kernel is available.
+    if (
+        recipe.float8_current_scaling()
+        or recipe.float8_block_scaling()
+        or recipe.nvfp4()
+        or recipe.custom()
+    ):
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
...
...
@@ -218,6 +227,7 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat
=
inp
.
view
((
-
1
,
in_features
))
inputmat
=
inp
.
view
((
-
1
,
in_features
))
if
fp8
:
if
fp8
:
assert_dim_for_fp8_exec
(
inputmat
,
fc1_weight
,
fc2_weight
)
assert_dim_for_fp8_exec
(
inputmat
,
fc1_weight
,
fc2_weight
)
assert_dim_for_all_gather
(
inputmat
,
sequence_parallel
,
fc1_input_quantizer
)
activation_func
=
_act_func
(
activation_func
=
_act_func
(
activation
,
FP8GlobalStateManager
.
get_fp8_recipe
()
if
fp8
else
None
activation
,
FP8GlobalStateManager
.
get_fp8_recipe
()
if
fp8
else
None
...
@@ -265,11 +275,13 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -265,11 +275,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
# for debug: : layernorm output = High precision to enable processing of this norm
experimental
=
is_experimental
(
fc1_input_quantizer
)
with_quantized_norm
=
(
with_quantized_norm
=
(
fp8
fp8
and
not
debug
and
not
debug
and
not
return_layernorm_output
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
and
not
return_layernorm_output_gathered
and
not
experimental
)
)
# Apply normalization
# Apply normalization
...
@@ -309,7 +321,8 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -309,7 +321,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer
=
None
quantizer
=
None
if
fp8
or
debug
:
if
fp8
or
debug
:
quantizer
=
fc1_input_quantizer
quantizer
=
fc1_input_quantizer
if
not
with_quantized_norm
:
# experimental recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
experimental
:
ln_out
=
fc1_input_quantizer
(
ln_out
)
ln_out
=
fc1_input_quantizer
(
ln_out
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag
:
if
ub_overlap_ag
:
...
@@ -447,10 +460,18 @@ class _LayerNormMLP(torch.autograd.Function):
                     act_out = fc2_input_quantizer(act_out)
             else:
                 fc1_out, *_ = fc1_outputs
-                if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
-                    # tex.quantize does not support GELU fusion for blockwise.
-                    act_out = activation_func(fc1_out, None)
-                    act_out = tex.quantize(act_out, fc2_input_quantizer)
+                if fp8:
+                    recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    if recipe.float8_block_scaling():
+                        # tex.quantize does not support GELU fusion for blockwise
+                        act_out = activation_func(fc1_out, None)
+                        act_out = tex.quantize(act_out, fc2_input_quantizer)
+                    elif recipe.custom():
+                        # tex.quantize does not support custom quantizers
+                        act_out = activation_func(fc1_out, None)
+                        act_out = fc2_input_quantizer(act_out)
+                    else:
+                        act_out = activation_func(fc1_out, fc2_input_quantizer)
                 else:
                     if fp8_calibration:
                         act_out = activation_func(fc1_out, None)
...
@@ -521,9 +542,9 @@ class _LayerNormMLP(torch.autograd.Function):
         if is_grad_enabled:
             # Weight with column-wise usage is needed for dgrad GEMM.
-            if isinstance(fc1_weight_final, QuantizedTensorBase):
+            if isinstance(fc1_weight_final, QuantizedTensorStorage):
                 fc1_weight_final.update_usage(columnwise_usage=True)
-            if isinstance(fc2_weight_final, QuantizedTensorBase):
+            if isinstance(fc2_weight_final, QuantizedTensorStorage):
                 fc2_weight_final.update_usage(columnwise_usage=True)
             if cpu_offloading:
...
@@ -555,6 +576,7 @@ class _LayerNormMLP(torch.autograd.Function):
             if not fc2_weight.requires_grad:
                 clear_tensor_data(act_out)
                 act_out = None
+
         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             ln_weight,
...
@@ -680,6 +702,7 @@ class _LayerNormMLP(torch.autograd.Function):
             mu,
             rsigma,
         ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
+
         # Delete the references to tensor objects once they've been consumed
         # by the `restore_from_saved` method to construct back the actual tensors.
         ctx.tensor_objects = None
...
@@ -820,10 +843,10 @@ class _LayerNormMLP(torch.autograd.Function):
         )
         # Make sure required data is available
-        if isinstance(grad_output, QuantizedTensorBase):
+        if isinstance(grad_output, QuantizedTensorStorage):
             grad_output.update_usage(rowwise_usage=True)
         if ctx.fc2_weight_quantizer is not None and isinstance(
-            ctx.fc2_weight, QuantizedTensorBase
+            ctx.fc2_weight, QuantizedTensorStorage
         ):
             ctx.fc2_weight.update_usage(columnwise_usage=True)
...
@@ -905,14 +928,14 @@ class _LayerNormMLP(torch.autograd.Function):
             # Note: Synchronize tensor-parallel communication and
             # make sure required data is available
             if ctx.fp8 or ctx.debug:
-                if isinstance(act_out, QuantizedTensorBase):
+                if isinstance(act_out, QuantizedTensorStorage):
                     act_out.update_usage(columnwise_usage=True)
                 else:
                     ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
                     act_out = ctx.fc2_input_quantizer(act_out)
             if ctx.fp8 or ctx.debug:
-                if isinstance(grad_output, QuantizedTensorBase):
+                if isinstance(grad_output, QuantizedTensorStorage):
                     grad_output.update_usage(columnwise_usage=True)
                 else:
                     ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
@@ -932,7 +955,11 @@ class _LayerNormMLP(torch.autograd.Function):
                     else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(fc1_weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
...
@@ -1028,8 +1055,11 @@ class _LayerNormMLP(torch.autograd.Function):
             )  # activation in high precision
             if ctx.fp8:
-                # TODO float8 blockwise current scaling has no bgrad fusion for now
-                if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
+                # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now
+                if (
+                    isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer)
+                    or ctx.fp8_recipe.custom()
+                ):
                     fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
                     dact = ctx.fc1_grad_output_quantizer(dact)
                 else:
...
@@ -1074,7 +1104,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Make sure required data is available
         if ctx.fc1_weight_quantizer is not None and isinstance(
-            ctx.fc1_weight_quantizer, QuantizedTensorBase
+            ctx.fc1_weight_quantizer, QuantizedTensorStorage
         ):
             ctx.fc1_weight.update_usage(columnwise_usage=True)
...
@@ -1145,7 +1175,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     ln_out_total_work.wait()
                     ln_out_total_work = None
                 if ctx.fp8 or ctx.debug:
-                    if isinstance(ln_out_total, QuantizedTensorBase):
+                    if isinstance(ln_out_total, QuantizedTensorStorage):
                         ln_out_total.update_usage(columnwise_usage=True)
                     else:
                         ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
...
@@ -1155,7 +1185,7 @@ class _LayerNormMLP(torch.autograd.Function):
             # Note: Synchronize tensor-parallel communication and
             # make sure required data is available
             if ctx.fp8 or ctx.debug:
-                if isinstance(dact, QuantizedTensorBase):
+                if isinstance(dact, QuantizedTensorStorage):
                     dact.update_usage(columnwise_usage=True)
                 else:
                     ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
@@ -1178,7 +1208,11 @@ class _LayerNormMLP(torch.autograd.Function):
                     else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.fc1_grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(fc2_weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
...
@@ -1486,7 +1520,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
                  the weight gradient. When enabled, it is assumed that the weights
                  have an additional `main_grad` attribute (used instead of the
                  regular `grad`) which is a pre-allocated buffer of the correct
-                 size to accumulate gradients in.
+                 size to accumulate gradients in. This argument along with
+                 weight tensor having attribute 'overwrite_main_grad' set to True
+                 will overwrite `main_grad` instead of accumulating.
     return_bias : bool, default = `False`
                  when set to `True`, this module will not apply the additive bias for FC2, but
                  instead return the bias value during the forward pass together with the
...
@@ -1718,6 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
         elif recipe.float8_block_scaling():
             self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
+        elif recipe.nvfp4():
+            self._customize_quantizers_nvfp4(fwd, recipe)
         # elif for other recipes (mxfp8, etc.)
     def reset_layer_norm_parameters(self) -> None:
...
@@ -1937,7 +1975,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
         fc2_input_quantizer.set_usage(
             rowwise=True,
-            columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
+            columnwise=isinstance(
+                fc2_input_quantizer,
+                (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
+            ),
         )
         fc1_input_quantizer.internal = True
         if fp8_output:
...
@@ -2142,7 +2183,29 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 tex.FP8BwdTensors.GRAD_OUTPUT2
             ].amax_reduction_group = self.tp_group
-    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + layernorm_mlp."""
+        assert recipe.nvfp4(), "Incorrect recipe."
+        if fwd:
+            if self.sequence_parallel and self.set_parallel_mode:
+                # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+        else:
+            if self.sequence_parallel and self.set_parallel_mode:
+                # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
+                ].amax_reduction_group = self.tp_group
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
         """Get the weight tensors of the module."""
         return [self.fc1_weight, self.fc2_weight]
...
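Both activation-related changes in this file follow the same pattern: recipes without a fused activation+quantization kernel (blockwise, NVFP4, custom) run the activation in high precision and quantize the result as a separate step. A self-contained toy of that fallback, with a stand-in quantizer (the real modules use `fc2_input_quantizer` and `tex.quantize` as shown in the hunks above):

    import torch
    import torch.nn.functional as F

    def fake_per_tensor_quantize(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for an FP8/NVFP4 quantizer: scale to a small grid and back.
        scale = x.abs().amax().clamp(min=1e-8) / 127.0
        return (x / scale).round().clamp(-127, 127) * scale

    fc1_out = torch.randn(16, 64)

    # Fallback path: activation in high precision, then quantize separately.
    act_out = fake_per_tensor_quantize(F.gelu(fc1_out))
    # Fused path (recipes with fused kernels): the quantizer is handed to the
    # activation itself, e.g. tex.gelu(fc1_out, fc2_input_quantizer).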
transformer_engine/pytorch/module/linear.py
...
@@ -27,7 +27,7 @@ from .base import (
     _2X_ACC_WGRAD,
 )
 from ._common import noop_cat, WeightGradStore
-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..utils import (
     cast_if_needed,
     clear_tensor_data,
...
@@ -36,6 +36,7 @@ from ..utils import (
     requires_grad,
     needs_quantized_gemm,
     assert_dim_for_fp8_exec,
+    assert_dim_for_all_gather,
     nvtx_range_pop,
     nvtx_range_push,
     get_activation_offloading,
...
@@ -60,13 +61,14 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
-    QuantizedTensorBase,
+    QuantizedTensorStorage,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor.utils import is_experimental
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
...
@@ -154,6 +156,9 @@ class _Linear(torch.autograd.Function):
             ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG
+        # experimental recipe check
+        experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
+
         # ------------------------------------------------------
         # Prepare input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
...
@@ -164,6 +169,7 @@ class _Linear(torch.autograd.Function):
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
+            assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
             if save_original_input:
                 assert not isinstance(
                     input_quantizer, Float8Quantizer
...
...
@@ -175,7 +181,7 @@ class _Linear(torch.autograd.Function):
if
fp8
or
debug
:
if
fp8
or
debug
:
if
input_quantizer
is
None
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
not
isinstance
(
inputmat
,
QuantizedTensor
Base
)
:
if
not
isinstance
(
inputmat
,
QuantizedTensor
Storage
)
and
not
experimental
:
own_quantized_input
=
True
own_quantized_input
=
True
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
if
isinstance
(
if
isinstance
(
...
@@ -213,7 +219,7 @@ class _Linear(torch.autograd.Function):
...
@@ -213,7 +219,7 @@ class _Linear(torch.autograd.Function):
else
:
# Do not all-gather input tensor
else
:
# Do not all-gather input tensor
if
fp8
or
debug
:
if
fp8
or
debug
:
if
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
):
inputmat
.
update_usage
(
rowwise_usage
=
True
)
inputmat
.
update_usage
(
rowwise_usage
=
True
)
else
:
else
:
if
input_quantizer
is
None
:
if
input_quantizer
is
None
:
...
@@ -369,7 +375,7 @@ class _Linear(torch.autograd.Function):
...
@@ -369,7 +375,7 @@ class _Linear(torch.autograd.Function):
if
(
if
(
backward_needs_input
backward_needs_input
and
own_quantized_input
and
own_quantized_input
and
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
)
and
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
)
):
):
if
(
if
(
ctx
.
backward_input_needs_gather
ctx
.
backward_input_needs_gather
...
@@ -388,7 +394,7 @@ class _Linear(torch.autograd.Function):
...
@@ -388,7 +394,7 @@ class _Linear(torch.autograd.Function):
# Weight with column-wise usage is needed for dgrad GEMM.
# Weight with column-wise usage is needed for dgrad GEMM.
if
inp
.
requires_grad
:
if
inp
.
requires_grad
:
if
isinstance
(
weightmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
weightmat
,
QuantizedTensor
Storag
e
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
and
saved_inputmat
is
not
None
:
if
cpu_offloading
and
saved_inputmat
is
not
None
:
...
@@ -401,7 +407,7 @@ class _Linear(torch.autograd.Function):
...
@@ -401,7 +407,7 @@ class _Linear(torch.autograd.Function):
ctx
.
fsdp_shapes
=
_fsdp_scatter_tensors
(
ctx
.
fsdp_shapes
=
_fsdp_scatter_tensors
(
fsdp_group
,
fsdp_group
,
saved_inputmat
,
saved_inputmat
,
weightmat
if
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Bas
e
)
else
None
,
weightmat
if
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Storag
e
)
else
None
,
)
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
...
@@ -471,6 +477,7 @@ class _Linear(torch.autograd.Function):
             ctx.main_grad_func = lambda: weight.main_grad
         ctx.debug = debug
+        ctx.experimental = experimental
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
         ctx.use_bias = bias is not None
...
@@ -637,10 +644,10 @@ class _Linear(torch.autograd.Function):
                 inputmat_total_work = None
         if ctx.requires_wgrad:
             if ctx.fp8 or ctx.debug:
-                if isinstance(inputmat, QuantizedTensorBase):
+                if isinstance(inputmat, QuantizedTensorStorage):
                     # Input tensor is already quantized
                     pass
-                elif ctx.debug:
+                elif ctx.debug or ctx.experimental:
                     # Debug quantizer will be applied immediately before wgrad GEMM
                     pass
                 else:
...
@@ -656,7 +663,7 @@ class _Linear(torch.autograd.Function):
                     quantizer.set_usage(rowwise=False, columnwise=True)
                     inputmat = quantizer(inputmat)
             else:
-                if isinstance(inputmat, QuantizedTensorBase):
+                if isinstance(inputmat, QuantizedTensorStorage):
                     inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                 else:
                     inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
...
@@ -701,9 +708,11 @@ class _Linear(torch.autograd.Function):
         if ctx.requires_dgrad:
             # Make sure required data is available
-            if isinstance(grad_output, QuantizedTensorBase):
+            if isinstance(grad_output, QuantizedTensorStorage):
                 grad_output.update_usage(rowwise_usage=True)
-            if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
+            if ctx.weight_quantizer is not None and isinstance(
+                weight_fp8, QuantizedTensorStorage
+            ):
                 weight_fp8.update_usage(columnwise_usage=True)
             # Choose whether to use GEMM kernel with split accumulator
...
@@ -729,6 +738,7 @@ class _Linear(torch.autograd.Function):
             # dgrad GEMM
             # Note: dx = dy * w
             nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+
             gemm_out, *_, reduce_scatter_out = general_gemm(
                 weight_fp8,
...
...
@@ -786,7 +796,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work
.
wait
()
inputmat_total_work
.
wait
()
inputmat_total_work
=
None
inputmat_total_work
=
None
if
ctx
.
fp8
or
ctx
.
debug
:
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
inputmat_total
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat_total
,
QuantizedTensor
Storag
e
):
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
else
:
else
:
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
@@ -828,7 +838,7 @@ class _Linear(torch.autograd.Function):
                 )
                 if ctx.fp8 or ctx.debug:
-                    if isinstance(grad_output, QuantizedTensorBase):
+                    if isinstance(grad_output, QuantizedTensorStorage):
                         grad_output.update_usage(columnwise_usage=True)
                     else:
                         ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
...
@@ -864,7 +874,11 @@ class _Linear(torch.autograd.Function):
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),
                 "quantization_params": ctx.grad_weight_quantizer,
-                "accumulate": accumulate_wgrad_into_param_main_grad,
+                "accumulate": (
+                    accumulate_wgrad_into_param_main_grad
+                    if not getattr(weight, "overwrite_main_grad", False)
+                    else False
+                ),
                 "layout": "NT",
                 "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                 "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
...
@@ -984,7 +998,7 @@ class _Linear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

         # Scatter fp8 weight buffers
-        if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
+        if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
             _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

         return (
             wgrad,
...
@@ -1086,7 +1100,9 @@ class Linear(TransformerEngineBaseModule):
                 the weight gradient. When enabled, it is assumed that the weights
                 have an additional `main_grad` attribute (used instead of the
                 regular `grad`) which is a pre-allocated buffer of the correct
-                size to accumulate gradients in.
+                size to accumulate gradients in. If this argument is enabled and
+                the weight tensor has the attribute `overwrite_main_grad` set to
+                `True`, `main_grad` is overwritten instead of accumulated into.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
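
The `overwrite_main_grad` behavior documented above is easiest to see in a small usage sketch. The snippet below is illustrative only and not part of this commit: it mimics how a training framework might pre-allocate `main_grad` and mark the weight so that the wgrad of a step overwrites the buffer instead of accumulating into it. `te.Linear` and `fuse_wgrad_accumulation` are the documented API; everything else (shapes, dtypes, the manual buffer setup) is an assumption.

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical setup: pre-allocate a main_grad buffer and ask the module to
# overwrite it rather than accumulate (mirrors the docstring semantics above).
layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True)
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)
layer.weight.overwrite_main_grad = True  # backward writes main_grad, does not add

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
layer(x).sum().backward()
# layer.weight.main_grad now holds this step's wgrad; with the attribute left
# at its default (False), the same backward would have added into the buffer.
```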
...
@@ -1363,6 +1379,8 @@ class Linear(TransformerEngineBaseModule):
                 self._customize_quantizers_float8_current_scaling(fwd, recipe)
             elif recipe.float8_block_scaling():
                 self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
+            elif recipe.nvfp4():
+                self._customize_quantizers_nvfp4(fwd, recipe)
             # elif for other recipes (mxfp8, etc.)

     def reset_parameters(self, defer_init=False):
...
@@ -1452,7 +1470,6 @@ class Linear(TransformerEngineBaseModule):
             if not debug
             else self._get_debug_quantizers(fp8_output, fp8_grad)
         )
         if debug:
             if self.no_debug_features_active(quantizers):
                 debug = False
...
@@ -1557,7 +1574,7 @@ class Linear(TransformerEngineBaseModule):
             for name, q in zip(names, original_quantizers)
         )

-    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
         """Get the weight tensors of the module."""
         unfused_weights = [getattr(self, name) for name in self.weight_names]
         if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
...
@@ -1693,6 +1710,28 @@ class Linear(TransformerEngineBaseModule):
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_reduction_group = self.tp_group

+    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + linear."""
+        assert recipe.nvfp4(), "Incorrect recipe."
+        if fwd:
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # customize input_quantizer with amax reduction TP group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+        else:
+            if self.sequence_parallel and self.parallel_mode == "row":
+                # customize grad_output_quantizer with amax reduction TP group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_group = self.tp_group
+
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
         if not self.fp8 and not self.fp8_calibration:
...
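
The NVFP4 customization above only enables amax reduction over the tensor-parallel group when sequence parallelism splits the tensor being quantized: each rank then sees only a shard, so a locally computed amax would give every shard a different scale. A rough standalone sketch of the idea in plain PyTorch follows; it is not TransformerEngine's implementation, the process-group plumbing is assumed, and 6.0 is the maximum magnitude representable in FP4 (E2M1).

```python
import torch
import torch.distributed as dist

def shared_amax_scale(local_shard: torch.Tensor, tp_group, fp4_max: float = 6.0) -> torch.Tensor:
    """Derive one quantization scale that is consistent across all
    sequence-parallel shards by reducing the amax over the TP group."""
    amax = local_shard.abs().amax().float()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)  # global amax
    return amax / fp4_max  # every rank now derives the same scale
```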
transformer_engine/pytorch/ops/_common.py  View file @ 063ef88d
...
@@ -11,19 +11,19 @@ import torch
 from transformer_engine_torch import FP8TensorMeta

 from .. import torch_version
-from ..fp8 import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager
 from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.quantized_tensor import QuantizedTensorBase
+from ..tensor.quantized_tensor import QuantizedTensorStorage
 from ..utils import canonicalize_dtype


-def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
+def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
     """Check if tensor is a quantized tensor"""
-    return isinstance(tensor, QuantizedTensorBase)
+    return isinstance(tensor, QuantizedTensorStorage)


 def maybe_dequantize(
-    tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
+    tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None
 ) -> torch.Tensor:
     """Dequantize tensor to given dtype or just convert if not a quantized tensor"""
     if is_quantized_tensor(tensor):
...
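
These two helpers let downstream ops accept either a plain tensor or a quantized-storage tensor without caring which they received. A minimal usage sketch (illustrative; the unfused-kernel scenario and GELU call are assumptions, only the helper names and module path come from this diff):

```python
import torch
from transformer_engine.pytorch.ops._common import is_quantized_tensor, maybe_dequantize

def run_unfused_kernel(x, compute_dtype=torch.bfloat16):
    # Plain tensors are simply cast; quantized-storage tensors are dequantized
    # to the requested dtype before reaching a kernel that cannot consume them.
    x_hp = maybe_dequantize(x, dtype=compute_dtype)
    assert not is_quantized_tensor(x_hp)
    return torch.nn.functional.gelu(x_hp)
```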
transformer_engine/pytorch/ops/basic/__init__.py  View file @ 063ef88d
...
@@ -4,7 +4,19 @@

 """Single tensor operations supported by the operation fuser."""

-from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
+from .activation import (
+    GELU,
+    GEGLU,
+    QGELU,
+    QGEGLU,
+    ReLU,
+    ReGLU,
+    SReLU,
+    SReGLU,
+    SiLU,
+    SwiGLU,
+    ClampedSwiGLU,
+)
 from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
...
transformer_engine/pytorch/ops/basic/activation.py  View file @ 063ef88d
...
@@ -28,6 +28,7 @@ __all__ = [
     "SReGLU",
     "SiLU",
     "SwiGLU",
+    "ClampedSwiGLU",
 ]
...
@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation):

     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
         return tex.dswiglu(*args, **kwargs)
+
+
+class ClampedSwiGLU(_ActivationOperation):
+    r"""Clamped SwiGLU used in GPT-OSS
+
+    Implementation based on `GPT-OSS <https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
+    This activation has two differences compared to the original SwiGLU:
+
+    1. Both the gate and the pre-activations are clipped based on the ``limit`` parameter.
+    2. The activation uses ``sigmoid(alpha * x)`` instead of the ``sigmoid(x)`` used in the Swish activation.
+
+    .. warning:: The input tensor is chunked along the last dimension to get the gates/pre-activations,
+         which is different from the GPT-OSS implementation, where the gates/pre-activations are assumed
+         to be interleaved in the input tensor.
+
+    Parameters
+    ----------
+    limit: float
+        The clamp limit.
+    alpha: float
+        The scaling factor for the sigmoid function used in the activation.
+    cache_quantized_input: bool, default = False
+        Quantize input tensor when caching for use in the backward pass.
+    """
+
+    def __init__(self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False):
+        super().__init__(cache_quantized_input=cache_quantized_input)
+        self.limit = limit
+        self.alpha = alpha
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
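
For readers who want the math rather than the fused kernel, below is a rough pure-PyTorch sketch of what a clamped SwiGLU of this shape computes, using the chunked (non-interleaved) layout described in the warning above. It mirrors the GPT-OSS reference, including its `+ 1` offset on the linear half, which may differ in detail from `tex.clamped_swiglu`; treat it as an illustration, not a specification of the kernel.

```python
import torch

def clamped_swiglu_reference(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    # Chunk along the last dimension: first half gates, second half linear units
    # (the fused op above chunks rather than interleaves).
    gate, linear = x.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)                  # clip pre-activation from above
    linear = linear.clamp(min=-limit, max=limit)  # clip linear half symmetrically
    swish = gate * torch.sigmoid(alpha * gate)    # sigmoid(alpha * x) variant of Swish
    return swish * (linear + 1)                   # GPT-OSS adds a +1 bias on the linear half
```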
transformer_engine/pytorch/ops/basic/basic_linear.py  View file @ 063ef88d
...
@@ -19,7 +19,7 @@ from ...distributed import (
     gather_along_first_dim,
     reduce_scatter_along_first_dim,
 )
-from ...fp8 import FP8GlobalStateManager, Recipe
+from ...quantization import FP8GlobalStateManager, Recipe
 from ...module.base import (
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
...
@@ -29,7 +29,7 @@ from ...module.base import (
 )
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
-from ...tensor._internal.float8_tensor_base import Float8TensorBase
+from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
...
@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
         meaningful. This is primarily intended to integrate with
-        Megatron-LM.
+        Megatron-LM. If this argument is enabled and the weight tensor
+        has the attribute `overwrite_main_grad` set to `True`, `main_grad`
+        is overwritten instead of accumulated into.
     userbuffers_options, dict, optional
         Options for overlapping tensor-parallel communication with
         compute using Userbuffers. This feature is highly
...
@@ -301,8 +303,8 @@ class BasicLinear(BasicOperation):
                     "Tried to quantize weight with deferred initialization "
                     "due to meta device, but no quantizer was available. "
                     "This is most likely because the weight was initialized "
-                    "within fp8_model_init, but the forward pass was not "
-                    "performed within fp8_autocast."
+                    "within quantized_model_init, but the forward pass was not "
+                    "performed within autocast."
                 )
             quantizer.set_usage(
                 rowwise=True,
...
@@ -322,6 +324,20 @@ class BasicLinear(BasicOperation):
         if self.weight.device.type == "meta":
             self.reset_parameters()

+    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
+        super().pre_fuser_forward(requires_grad=requires_grad)
+        if FP8GlobalStateManager.is_fp8_enabled():
+            # Configure quantizer usages
+            # Note: We cache the quantized input for backward pass,
+            # but discard the quantized weights.
+            weight_requires_grad = requires_grad and self.weight.requires_grad
+            input_quantizer = self.get_quantizer("forward", 0)
+            weight_quantizer = self.get_quantizer("forward", 1)
+            grad_output_quantizer = self.get_quantizer("backward", 0)
+            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+            weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+
     def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
         super().reset_recipe_state(recipe=recipe)
...
@@ -352,6 +368,35 @@ class BasicLinear(BasicOperation):
             and not getattr(self, "_with_quantized_weight", False)
         )

+        # Recipe-specific configuration
+        # Note: This function may be called in base class constructor,
+        # before any basic linear attrs have been set.
+        if recipe is not None:
+            if recipe.float8_current_scaling():
+                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
+                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
+                if getattr(self, "sequence_parallel", False):
+                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
+                    if tensor_parallel_mode == "column":
+                        input_quantizer.with_amax_reduction = True
+                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                    elif tensor_parallel_mode == "row":
+                        grad_output_quantizer.with_amax_reduction = True
+                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+            if recipe.nvfp4():
+                if getattr(self, "sequence_parallel", False):
+                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
+                    if tensor_parallel_mode == "column":
+                        input_quantizer.with_amax_reduction = True
+                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                    elif tensor_parallel_mode == "row":
+                        grad_output_quantizer.with_amax_reduction = True
+                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+
     @staticmethod
     def _functional_forward(
         input: torch.Tensor,  # pylint: disable=redefined-builtin
...
@@ -568,7 +613,7 @@ class BasicLinear(BasicOperation):
         # Prepare input tensor for backward pass
         if weight_requires_grad:
             if with_quantized_compute and is_quantized_tensor(x_local):
-                if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
+                if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather):
                     # FP8 does not support all-gather of transpose data
                     x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
             else:
...
@@ -731,7 +776,7 @@ class BasicLinear(BasicOperation):
         if with_quantized_compute:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(columnwise=True)
+            input_quantizer.set_usage(rowwise=False, columnwise=True)
         if with_x_all_gather:
             x, x_async = gather_along_first_dim(
                 x_local,
...
@@ -912,34 +957,13 @@ class BasicLinear(BasicOperation):
         input_requires_grad = ctx.requires_grad
         weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = self.get_quantizer("forward", 0)
         weight_quantizer = self.get_quantizer("forward", 1)
         output_quantizer = next_op_input_quantizer
         grad_output_quantizer = self.get_quantizer("backward", 0)
         grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        if with_quantized_compute:
-            # Configure quantizers
-            # Note: We cache the quantized input for backward pass,
-            # but discard the quantized weights.
-            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
-            weight_quantizer.set_usage(rowwise=True, columnwise=False)
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-            if recipe.float8_current_scaling():
-                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                if self.sequence_parallel and self.tensor_parallel_mode == "column":
-                    input_quantizer.with_amax_reduction = True
-                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
-                if self.sequence_parallel and self.tensor_parallel_mode == "row":
-                    grad_output_quantizer.with_amax_reduction = True
-                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
...
@@ -997,6 +1021,7 @@ class BasicLinear(BasicOperation):
             weight_param = self.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
...
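
The comment in the new `pre_fuser_forward` hook ("cache the quantized input for backward pass, but discard the quantized weights") follows directly from the linear-layer gradients: for y = x·Wᵀ the input gradient needs only the weight, while the weight gradient needs the saved input, so the column-wise (transposed) copy of x is worth keeping only when the weight actually requires a gradient. A plain-PyTorch sketch of the two backward GEMMs, with purely illustrative shapes and names:

```python
import torch

# Illustrative shapes: 8 tokens, in_features=4, out_features=3.
x = torch.randn(8, 4)    # saved input (needed only for wgrad)
w = torch.randn(3, 4)    # weight, forward is y = x @ w.T
dy = torch.randn(8, 3)   # incoming gradient

dgrad = dy @ w           # dL/dx: uses the weight, not the saved input
wgrad = dy.t() @ x       # dL/dw: uses the transposed ("columnwise") input
```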
transformer_engine/pytorch/ops/basic/dropout.py  View file @ 063ef88d
...
@@ -11,7 +11,7 @@ import torch
 import transformer_engine_torch as tex

 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...tensor import Quantizer
-from ...tensor._internal.float8_tensor_base import Float8TensorBase
+from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
 from .._common import maybe_autocast_dtype, maybe_dequantize
 from ..op import BasicOperation, OperationContext
...
@@ -56,7 +56,7 @@ class Dropout(BasicOperation):
             out = input_
         elif impl == "fused":
             x = input_
-            if not isinstance(x, Float8TensorBase):
+            if not isinstance(x, Float8TensorStorage):
                 x = maybe_dequantize(x, dtype=dtype)
             out, mask = tex.dropout_fwd(x, self.dropout_probability)
         elif impl == "unfused":
...
transformer_engine/pytorch/ops/basic/quantize.py  View file @ 063ef88d
...
@@ -9,7 +9,7 @@ from typing import Optional

 import torch

-from ...fp8 import FP8GlobalStateManager
+from ...quantization import FP8GlobalStateManager
 from .._common import is_quantized_tensor
 from ..op import BasicOperation, OperationContext
 from ...tensor import Quantizer
...
@@ -18,8 +18,8 @@ from ...tensor import Quantizer
 class Quantize(BasicOperation):
     """Quantize tensor data

-    Uses FP8 recipe from `fp8_autocast` context. When called outside
-    of an `fp8_autocast` context, this is an identity operation.
+    Uses recipe from `autocast` context. When called outside
+    of an `autocast` context, this is an identity operation.

     Parameters
     ----------
...
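
This commit consistently renames the FP8-specific entry points to recipe-agnostic ones (`transformer_engine.pytorch.fp8` → `transformer_engine.pytorch.quantization`, `fp8_autocast` → `autocast`), which is what the updated `Quantize` docstring refers to. A hedged before/after usage sketch follows; the import path for `autocast` is taken from this diff, but whether the old spelling remains as a deprecated alias is not shown here, and the `ops.Sequential` pipeline is only an illustration.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantization import autocast  # new name per this commit
from transformer_engine.common.recipe import DelayedScaling

model = te.ops.Sequential(te.ops.Quantize(), te.ops.BasicLinear(1024, 1024))
x = torch.randn(16, 1024, device="cuda")

# Old spelling: with te.fp8_autocast(fp8_recipe=DelayedScaling()): ...
with autocast(recipe=DelayedScaling()):
    y_quantized = model(x)

y_plain = model(x)  # outside the context, Quantize() acts as an identity op
```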
transformer_engine/pytorch/ops/fused/backward_activation_bias.py  View file @ 063ef88d
...
@@ -10,7 +10,7 @@ from typing import Optional
 import torch
 import transformer_engine_torch as tex

-from transformer_engine.pytorch.fp8 import Recipe
+from transformer_engine.pytorch.quantization import Recipe
 from transformer_engine.pytorch.ops.basic import Bias
 from transformer_engine.pytorch.ops.basic.activation import (
     _ActivationOperation,
...
transformer_engine/pytorch/ops/fused/backward_linear_add.py  View file @ 063ef88d
...
@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation):
             weight_param = linear_op.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
...
transformer_engine/pytorch/ops/fused/backward_linear_scale.py  View file @ 063ef88d
...
@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation):
             weight_param = linear_op.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py  View file @ 063ef88d
...
@@ -11,7 +11,7 @@ from typing import Any, Optional
 import torch

 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
-from ...fp8 import FP8GlobalStateManager
+from ...quantization import FP8GlobalStateManager
 from ...tensor import Quantizer
 from ..basic import BasicLinear, Bias
 from ..op import FusedOperation, FusibleOperation, OperationContext
...
@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
         input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = next_op_input_quantizer
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py  View file @ 063ef88d
...
@@ -11,7 +11,7 @@ from typing import Any, Optional
 import torch

 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
-from ...fp8 import FP8GlobalStateManager
+from ...quantization import FP8GlobalStateManager
 from ...tensor import Quantizer
 from ..basic import AddExtraInput, BasicLinear, Bias
 from ..op import FusedOperation, FusibleOperation, OperationContext
...
@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py  View file @ 063ef88d
...
@@ -11,7 +11,7 @@ from typing import Any, Optional
 import torch

 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
-from ...fp8 import FP8GlobalStateManager
+from ...quantization import FP8GlobalStateManager
 from ...tensor import Quantizer
 from ..basic import AddExtraInput, BasicLinear, ConstantScale
 from ..op import (
...
@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation):
         input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py  View file @ 063ef88d
...
@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             weight_param = linear_op.weight
             if hasattr(weight_param, "__fsdp_param__"):
                 weight_param.main_grad = weight_param.get_main_grad()
+            accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
             if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py  View file @ 063ef88d
...
@@ -14,7 +14,7 @@ from transformer_engine_torch import CommOverlapType
 from ...cpp_extensions import general_gemm
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...distributed import get_distributed_world_size
-from ...fp8 import FP8GlobalStateManager
+from ...quantization import FP8GlobalStateManager
 from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_ub,
...
@@ -23,7 +23,7 @@ from ...module.base import (
 )
 from ...tensor.quantized_tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
-from ...tensor._internal.float8_tensor_base import Float8TensorBase
+from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
 from .._common import maybe_dequantize, is_quantized_tensor
 from ..basic import BasicLinear, Bias, ReduceScatter
 from ..op import (
...
@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation):
         # Prepare input tensor for backward pass
         if weight_requires_grad:
             if with_quantized_compute and is_quantized_tensor(x_local):
-                if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
+                if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather):
                     # FP8 does not support all-gather of transpose data
                     x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
             else:
...
transformer_engine/pytorch/ops/fuser.py  View file @ 063ef88d
...
@@ -11,7 +11,7 @@ import itertools

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
...
@@ -472,6 +472,10 @@ class OperationFuser:
         # Attempt to fuse operations if necessary
         self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)

+        # Initialization before forward
+        for idx, op in enumerate(self._basic_ops):
+            op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)
+
         # Fuser forward pass
         if is_grad_enabled:
             forward_func = _OperationFuserAutogradFunction.apply
...
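
Taken together with the new `pre_fuser_forward` hooks in `op.py` and `basic_linear.py`, the fuser now gives every basic op a chance to configure itself once per forward pass (for example, quantizer row-wise/column-wise usage), with a per-op flag indicating whether that op will participate in the backward pass. A stripped-down sketch of the dispatch pattern, written outside TransformerEngine with purely illustrative class and attribute names:

```python
class Op:
    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
        """Default hook: no per-forward setup needed."""

class LinearLikeOp(Op):
    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
        # Only keep extra (columnwise) copies of cached tensors if this op
        # actually needs them for its backward pass.
        self.cache_columnwise_input = requires_grad

class Fuser:
    def __init__(self, ops, first_op_requiring_backward: int):
        self.ops = ops
        self.first_op_requiring_backward = first_op_requiring_backward

    def forward(self, x):
        for idx, op in enumerate(self.ops):
            # Ops before the first one that needs a backward pass can skip
            # backward-only setup, mirroring the loop added in fuser.py.
            op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)
        # ... run the (possibly fused) forward computation ...
        return x
```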
transformer_engine/pytorch/ops/op.py  View file @ 063ef88d
...
@@ -14,10 +14,10 @@ from typing import Any, Optional
 import torch

 from transformer_engine.common.recipe import Recipe
-from ..fp8 import (
+from ..quantization import (
     FP8GlobalStateManager,
     RecipeState,
-    fp8_autocast,
+    autocast,
 )
 from ..tensor import Quantizer
...
@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
     def pre_first_fuser_forward(self) -> None:
         """Preprocessing before first fuser forward pass"""

+    def pre_fuser_forward(
+        self,
+        *,
+        requires_grad: bool,  # pylint: disable=unused-argument
+    ) -> None:
+        """Preprocessing before fuser forward pass"""
+
     def get_input_quantizer(self) -> Optional[Quantizer]:
         """Get builder class for quantized input tensor"""
...
@@ -588,6 +595,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 extra[key] = val
             state[mode]["extra_fp8_variables"] = extra

+        if not state:
+            return torch.empty(0, dtype=torch.uint8)
+
         # Serialize state into byte tensor
         torch.cuda.synchronize()
         state_serialized = bytearray(pickle.dumps(state))
...
@@ -624,7 +634,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         # Get op's quantizer state, initializing if needed
         if self._fp8_metas is None or self._fp8_metas[mode] is None:
-            with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
+            with autocast(recipe=state[mode]["recipe"]):
                 self.reset_recipe_state(recipe=state[mode]["recipe"])
         fp8_meta = self._fp8_metas[mode]
...
@@ -710,6 +720,10 @@ class FusedOperation(FusibleOperation):
         for op in self.basic_ops:
             op.pre_first_fuser_forward()

+    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
+        for op in self.basic_ops:
+            op.pre_fuser_forward(requires_grad=requires_grad)
+
     def forward(
         self,
         input: torch.Tensor,  # pylint: disable=redefined-builtin
...