OpenDAS / TransformerEngine / Commits / 44740c6c

Commit 44740c6c, authored Jul 22, 2025 by yuguo

    Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine

Parents: 8113d9e0, 7a9a0825
Changes: 162
Showing 20 changed files with 724 additions and 176 deletions (+724, -176)
transformer_engine/pytorch/ops/fused/backward_linear_add.py                  +6    -4
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py       +13   -12
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py              +13   -9
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py          +16   -22
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py           +26   -39
transformer_engine/pytorch/ops/fuser.py                                      +62   -23
transformer_engine/pytorch/ops/op.py                                         +59   -42
transformer_engine/pytorch/ops/sequential.py                                 +17   -2
transformer_engine/pytorch/permutation.py                                    +11   -4
transformer_engine/pytorch/router.py                                         +275  -0
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py  +6    -3
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py            +17   -0
transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py             +55   -5
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py                 +32   -0
transformer_engine/pytorch/tensor/float8_tensor.py                           +27   -0
transformer_engine/pytorch/tensor/mxfp8_tensor.py                            +61   -7
transformer_engine/pytorch/tensor/quantized_tensor.py                        +6    -0
transformer_engine/pytorch/tensor/utils.py                                   +1    -1
transformer_engine/pytorch/transformer.py                                    +7    -1
transformer_engine/pytorch/triton/cross_entropy.py                           +14   -2
transformer_engine/pytorch/ops/fused/backward_linear_add.py

@@ -57,6 +57,9 @@ class BackwardLinearAdd(FusedOperation):
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
+            if hasattr(linear_op.weight, "__fsdp_param__"):
+                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
             if not hasattr(linear_op.weight, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "

@@ -93,7 +96,6 @@ class BackwardLinearAdd(FusedOperation):
             grad_weight = None

         # Clear input tensor if possible
         if linear_op_ctx.has_prev_op:
             clear_tensor_data(x_local)

         return grad_input, [(grad_weight,), ()], [(), ()]

@@ -107,13 +109,13 @@ def fuse_backward_linear_add(
     Parameters
     ----------
     ops: list of tuples
-        Forward pass operations and the indices of the corresponding
+        Backward pass operations and the indices of the corresponding
         basic operations.

     Returns
     -------
     ops: list of tuples
-        Updated forward pass operations
+        Updated backward pass operations

     """
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py

@@ -13,11 +13,11 @@ import torch
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusedOperation,
     FusibleOperation,
     OperationContext,
 )
+from ...tensor import Quantizer


 class ForwardLinearBiasActivation(FusedOperation):

@@ -59,8 +59,8 @@ class ForwardLinearBiasActivation(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:

@@ -70,10 +70,12 @@ class ForwardLinearBiasActivation(FusedOperation):
         linear_op_ctx = basic_op_ctxs[idx]
         if self._op_idxs["bias"] is None:
             bias_op = None
+            bias_op_ctx = None
             bias = None
         else:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")

@@ -83,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
             raise NotImplementedError("Activations are not yet supported")

         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # FP8 metadata

@@ -96,18 +98,15 @@ class ForwardLinearBiasActivation(FusedOperation):
         if with_quantized_compute:
             input_quantizer = linear_op.get_quantizer("forward", 0)
             weight_quantizer = linear_op.get_quantizer("forward", 1)
-            next_op = basic_op_next_ops[-1]
-            if next_op is not None and next_op.num_quantizers("forward") > 0:
-                output_quantizer = next_op.get_quantizer("forward", 0)
+            output_quantizer = next_op_input_quantizer
             grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            prev_op = basic_op_prev_ops[0]
-            if prev_op is not None and prev_op.num_quantizers("backward") > 0:
-                grad_input_quantizer = prev_op.get_quantizer("backward", 0)
+            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
         else:
             dtype = linear_op.weight.dtype

         # Linear forward
         output, x_local, w = BasicLinear._functional_forward(

@@ -136,7 +135,9 @@ class ForwardLinearBiasActivation(FusedOperation):
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_requires_grad = input_requires_grad
         linear_op_ctx.weight_requires_grad = weight_requires_grad
-        linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
+        if bias_op is not None:
+            bias_op_ctx.with_quantized_compute = with_quantized_compute
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py

@@ -13,11 +13,11 @@ import torch
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusedOperation,
     FusibleOperation,
     OperationContext,
 )
+from transformer_engine.pytorch.tensor import Quantizer


 class ForwardLinearBiasAdd(FusedOperation):

@@ -57,8 +57,8 @@ class ForwardLinearBiasAdd(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:

@@ -68,16 +68,18 @@ class ForwardLinearBiasAdd(FusedOperation):
         linear_op_ctx = basic_op_ctxs[idx]
         if self._op_idxs["bias"] is None:
             bias_op = None
+            bias_op_ctx = None
             bias = None
         else:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")

         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # FP8 metadata

@@ -91,14 +93,13 @@ class ForwardLinearBiasAdd(FusedOperation):
             input_quantizer = linear_op.get_quantizer("forward", 0)
             weight_quantizer = linear_op.get_quantizer("forward", 1)
             grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            prev_op = basic_op_prev_ops[0]
-            if prev_op is not None and prev_op.num_quantizers("backward") > 0:
-                grad_input_quantizer = prev_op.get_quantizer("backward", 0)
+            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
         else:
             dtype = linear_op.weight.dtype

         # Linear forward
         output = basic_op_extra_inputs[self._op_idxs["add"]][0]

@@ -106,6 +107,7 @@ class ForwardLinearBiasAdd(FusedOperation):
             input=input_,
             weight=linear_op.weight,
             bias=bias,
             dtype=output.dtype,
             out=output,
             accumulate_into_out=True,
             tensor_parallel_mode=linear_op.tensor_parallel_mode,

@@ -129,7 +131,9 @@ class ForwardLinearBiasAdd(FusedOperation):
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_requires_grad = input_requires_grad
         linear_op_ctx.weight_requires_grad = weight_requires_grad
-        linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
+        if bias_op is not None:
+            bias_op_ctx.with_quantized_compute = with_quantized_compute
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py

@@ -20,10 +20,11 @@ from ...module.base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+from ...tensor.quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
 from ..basic import BasicLinear, Bias, ReduceScatter
+from .._common import maybe_dequantize, is_quantized_tensor
 from ..op import FusedOperation, FusibleOperation, OperationContext

@@ -279,7 +280,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Cast grad output tensor dtype if needed
         dy_local = grad_output
         if with_quantized_compute:
-            if not isinstance(dy_local, QuantizedTensorBase):
+            if not is_quantized_tensor(dy_local):
                 with_columnwise = weight_requires_grad
                 if (
                     with_columnwise

@@ -293,24 +294,18 @@ class UserbuffersBackwardLinear(FusedOperation):
                 )
                 dy_local = grad_output_quantizer(dy_local)
         else:
-            if isinstance(dy_local, QuantizedTensorBase):
-                dy_local = dy_local.dequantize(dtype=dtype)
-            elif dy_local.dtype != dtype:
-                dy_local = dy_local.to(dtype=dtype)
+            dy_local = maybe_dequantize(dy_local, dtype)

         # Cast weight tensor dtype if needed
         if weight is None:
             raise ValueError("Weight tensor is required to compute input grad")
         w = weight
         if with_quantized_compute:
-            if not isinstance(w, QuantizedTensorBase):
+            if not is_quantized_tensor(w):
                 weight_quantizer.set_usage(columnwise=True)
                 w = weight_quantizer(w)
         else:
-            if isinstance(w, QuantizedTensorBase):
-                w = w.dequantize(dtype=dtype)
-            elif w.dtype != dtype:
-                w = w.to(dtype=dtype)
+            w = maybe_dequantize(w, dtype)

         # Cast input tensor dtype if needed
         x_local = None

@@ -319,14 +314,11 @@ class UserbuffersBackwardLinear(FusedOperation):
             raise ValueError("Input tensor is required to compute weight grad")
         x_local = input
         if with_quantized_compute:
-            if not isinstance(x_local, QuantizedTensorBase):
+            if not is_quantized_tensor(x_local):
                 input_quantizer.set_usage(columnwise=True)
                 x_local = input_quantizer(x_local)
         else:
-            if isinstance(x_local, QuantizedTensorBase):
-                x_local = x_local.dequantize(dtype=dtype)
-            elif x_local.dtype != dtype:
-                x_local = x_local.to(dtype=dtype)
+            x_local = maybe_dequantize(x_local, dtype)

         # dgrad GEMM
         dx_local = None

@@ -433,7 +425,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             raise RuntimeError(
                 "wgrad GEMM requires grad output tensor, which has not been initialized"
             )
-        if isinstance(dy, QuantizedTensorBase):
+        if is_quantized_tensor(dy):
             dy.update_usage(rowwise_usage=False, columnwise_usage=True)

         # Initialize input tensor

@@ -443,7 +435,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             raise RuntimeError(
                 "wgrad GEMM requires input tensor, which has not been initialized"
             )
-        if isinstance(x, QuantizedTensorBase):
+        if is_quantized_tensor(x):
             x.update_usage(rowwise_usage=False, columnwise_usage=True)

         # Check grad weight tensor

@@ -516,6 +508,9 @@ class UserbuffersBackwardLinear(FusedOperation):
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
+            if hasattr(linear_op.weight, "__fsdp_param__"):
+                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
             if not hasattr(linear_op.weight, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "

@@ -552,7 +547,6 @@ class UserbuffersBackwardLinear(FusedOperation):
             grad_bias = extra_outputs["grad_bias"]

         # Clear input tensor if possible
         if linear_op_ctx.has_prev_op:
             clear_tensor_data(x_local)

         # Return gradients

@@ -574,13 +568,13 @@ def fuse_userbuffers_backward_linear(
     Parameters
     ----------
     ops: list of tuples
-        Forward pass operations and the indices of the corresponding
+        Backward pass operations and the indices of the corresponding
         basic operations.

     Returns
     -------
     ops: list of tuples
-        Updated forward pass operations
+        Updated backward pass operations

     """
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py

@@ -20,13 +20,12 @@ from ...module.base import (
     get_workspace,
     _2X_ACC_FPROP,
 )
-from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
+from ...tensor.quantized_tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ...utils import canonicalize_device, canonicalize_dtype
+from .._common import maybe_dequantize, is_quantized_tensor
 from ..basic import BasicLinear, Bias, ReduceScatter
 from ..op import (
     BasicOperation,
     FusedOperation,
     FusibleOperation,
     OperationContext,

@@ -88,8 +87,8 @@ class UserbuffersForwardLinear(FusedOperation):
         weight: torch.Tensor,
         *,
         bias: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+        device: torch.device,
+        dtype: torch.dtype,
         tensor_parallel_mode: Optional[str] = None,
         tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
         tensor_parallel_size: Optional[int] = None,

@@ -112,9 +111,9 @@ class UserbuffersForwardLinear(FusedOperation):
             Weight tensor
         bias: torch.Tensor, optional
             Bias tensor
-        device: torch.device, default = default CUDA device
+        device: torch.device
             Tensor device
-        dtype: torch.dtype, default = default dtype
+        dtype: torch.dtype
             Tensor datatype
         tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
             Mode for tensor parallelism

@@ -156,16 +155,10 @@ class UserbuffersForwardLinear(FusedOperation):
         """

         # Check device
-        if device is None:
-            device = weight.device
-        device = canonicalize_device(device)
         if device.type != "cuda":
             raise ValueError(f"Only CUDA devices are supported (got {device})")

         # Check datatype
-        if dtype is None:
-            dtype = weight.dtype
-        dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

@@ -206,7 +199,7 @@ class UserbuffersForwardLinear(FusedOperation):
         x = None
         if with_ub_all_gather:
             if input_quantizer is not None:
-                if not isinstance(x_local, QuantizedTensorBase):
+                if not is_quantized_tensor(x_local):
                     input_quantizer.set_usage(
                         rowwise=True,
                         columnwise=weight_requires_grad,
                     )
                     if isinstance(input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)

@@ -222,26 +215,20 @@ class UserbuffersForwardLinear(FusedOperation):
                 )
         else:
             if with_quantized_compute:
-                if not isinstance(x_local, QuantizedTensorBase):
+                if not is_quantized_tensor(x_local):
                     input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
                     x_local = input_quantizer(x_local)
             else:
-                if isinstance(x_local, QuantizedTensorBase):
-                    x_local = x_local.dequantize(dtype=dtype)
-                if x_local.dtype != dtype:
-                    x_local = x_local.to(dtype=dtype)
+                x_local = maybe_dequantize(x_local, dtype)
             x = x_local

         # Initialize weight tensor
         w = weight
-        w_is_quantized = isinstance(w, QuantizedTensorBase)
-        if with_quantized_compute and not w_is_quantized:
+        if not with_quantized_compute:
+            w = maybe_dequantize(w, dtype)
+        elif with_quantized_compute and not is_quantized_tensor(w):
             weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             w = weight_quantizer(w)
-        elif not with_quantized_compute and w_is_quantized:
-            w = w.dequantize()
-        if not with_quantized_compute and w.dtype != dtype:
-            w = w.to(dtype=dtype)

         # Construct output tensor if needed
         reduce_scatter_output = None

@@ -271,18 +258,14 @@ class UserbuffersForwardLinear(FusedOperation):
         # Prepare weight tensor for backward pass
         if input_requires_grad:
-            if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
+            if w is not weight and with_quantized_compute and is_quantized_tensor(w):
                 w.update_usage(rowwise_usage=False, columnwise_usage=True)
         else:
             w = None

         # Prepare input tensor for backward pass
         if weight_requires_grad:
             if x_local is input:
                 # PyTorch autograd produces esoteric errors if we
                 # cache input tensor directly.
                 x_local = x_local.detach()
-            if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
+            if with_quantized_compute and is_quantized_tensor(x_local):
                 if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
                     # FP8 does not support all-gather of transpose data
                     x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

@@ -299,8 +282,8 @@ class UserbuffersForwardLinear(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:

@@ -309,16 +292,18 @@ class UserbuffersForwardLinear(FusedOperation):
         linear_op = self.basic_ops[idx]
         linear_op_ctx = basic_op_ctxs[idx]
         bias_op = None
+        bias_op_ctx = None
         bias = None
         if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]
             bias_op = self.basic_ops[idx]
+            bias_op_ctx = basic_op_ctxs[idx]
             bias = bias_op.bias
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")

         # Check which grads are required
-        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # Quantization metadata

@@ -336,14 +321,13 @@ class UserbuffersForwardLinear(FusedOperation):
             input_quantizer = linear_op.get_quantizer("forward", 0)
             weight_quantizer = linear_op.get_quantizer("forward", 1)
             grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            prev_op = basic_op_prev_ops[0]
-            if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
-                grad_input_quantizer = prev_op.get_quantizer("backward", 0)
+            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
         else:
             dtype = linear_op.weight.dtype

         # Userbuffers options
         if linear_op._userbuffers_options is None:

@@ -355,6 +339,7 @@ class UserbuffersForwardLinear(FusedOperation):
             weight=linear_op.weight,
             bias=bias,
             dtype=dtype,
+            device=linear_op.weight.device,
             tensor_parallel_mode=self.tensor_parallel_mode,
             tensor_parallel_group=self.tensor_parallel_group,
             tensor_parallel_size=self.tensor_parallel_size,

@@ -381,7 +366,9 @@ class UserbuffersForwardLinear(FusedOperation):
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_requires_grad
         linear_op_ctx.weight_requires_grad = weight_requires_grad
-        linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
+        if bias_op is not None:
+            bias_op_ctx.with_quantized_compute = with_quantized_compute
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
transformer_engine/pytorch/ops/fuser.py

@@ -10,19 +10,24 @@ from typing import Any, Optional
 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
     OperationContext,
 )
 from transformer_engine.pytorch.ops.fused import (
+    fuse_backward_bias_activation,
     fuse_backward_linear_add,
     fuse_forward_linear_bias_activation,
     fuse_forward_linear_bias_add,
     fuse_userbuffers_backward_linear,
     fuse_userbuffers_forward_linear,
 )
+from transformer_engine.pytorch.tensor.quantized_tensor import (
+    prepare_for_saving,
+    restore_from_saved,
+)


 def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:

@@ -96,6 +101,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         # Operation autograd contexts
         basic_op_ctxs = [OperationContext() for _ in range(fuser._num_basic_ops)]

+        # Mark input tensors as not deletable in backward
+        for tensor in (input_,) + params_and_extra_inputs:
+            tensor.do_not_clear = True
+
         # Unflatten list of parameters and extra tensor inputs
         extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs:]
         basic_op_extra_inputs = []

@@ -106,6 +115,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         # Apply forward ops
         x = input_
         requires_grad = is_grad_enabled and x.requires_grad
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         extra_outputs = [None] * fuser._num_basic_ops
         for op, basic_op_idxs in fuser._forward_ops:

@@ -117,31 +127,31 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                 requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
             for idx in basic_op_idxs:
                 basic_op_ctxs[idx].requires_grad = requires_grad
-            if requires_grad != x.requires_grad:
-                if requires_grad:
-                    x.requires_grad_()
-                else:
-                    x = x.detach()

             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
-            prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
-            next_ops = [
-                fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
-                for idx in basic_op_idxs
-            ]
+            prev_op_idx = basic_op_idxs[0] - 1
+            prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
+            prev_op_grad_input_quantizer = None
+            if prev_op is not None and with_quantized_compute:
+                prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
+            next_op_idx = basic_op_idxs[-1] + 1
+            next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
+            next_op_input_quantizer = None
+            if next_op is not None and with_quantized_compute:
+                next_op_input_quantizer = next_op.get_input_quantizer()
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
                 x,
                 basic_op_extra_inputs=extra_inputs,
-                basic_op_prev_ops=prev_ops,
-                basic_op_next_ops=next_ops,
+                prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+                next_op_input_quantizer=next_op_input_quantizer,
                 basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
             )
-            x.requires_grad_(requires_grad=requires_grad)
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                 for y in ys:
-                    y.requires_grad_(requires_grad=requires_grad)
+                    y.requires_grad_(requires_grad)
                 extra_outputs[idx] = ys

         # Flatten list of extra outputs

@@ -169,6 +179,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             range_end = len(to_save)
             ctx.to_save = None
             ctx._saved_tensors_range = (range_start, range_end)
+
+        # Save tensors for backward
+        if with_quantized_compute:
+            tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
+            func_ctx.save_for_backward(*tensors_to_save)
+            func_ctx.tensor_objects = tensor_objects
+        else:
             func_ctx.save_for_backward(*to_save)

         # Other context

@@ -179,9 +196,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         func_ctx.num_extra_inputs = fuser._num_extra_inputs
         func_ctx.num_extra_outputs = len(extra_outputs_flat)
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
+        func_ctx.with_quantized_compute = with_quantized_compute
+
+        x.requires_grad_(requires_grad)

         if extra_outputs_flat:
             return x, *extra_outputs_flat
         return x

     @staticmethod

@@ -198,8 +219,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         basic_ops = func_ctx.basic_ops
         basic_op_ctxs = func_ctx.basic_op_ctxs

-        # Unflatten list of saved tensors
+        # Restore saved tensors
+        if func_ctx.with_quantized_compute:
+            saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
+        else:
+            saved_tensors = func_ctx.saved_tensors
+
+        # Unflatten list of saved tensors
         for ctx in basic_op_ctxs:
             ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
             ctx._saved_tensors_range = None

@@ -291,15 +317,19 @@ class OperationFuser:
     ----------
     ops: list of FusibleOperation
         Pipeline of operations
-    fuse_ops: bool, default = `True`
+    fuse_ops: bool
         Whether to attempt fusing operations
+    recipe: Recipe, optional
+        Quantization recipe to use when fusing and executing operations.
+        Note: certain fusions may depend on what kind of recipe is being used.

     """

     def __init__(
         self,
         ops: list[FusibleOperation],
-        fuse_ops: bool = True,
+        fuse_ops: bool,
+        recipe: Optional[Recipe],
     ) -> None:

         # Get list of basic operations

@@ -321,7 +351,11 @@ class OperationFuser:
         self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
         self._backward_ops = list(reversed(self._forward_ops))

+        # Flag for checking if this is the first iteration
+        self._is_first_forward = True
+
         # Fuse ops if needed
+        self.recipe = recipe
         if fuse_ops:
             self.fuse_ops()

@@ -333,6 +367,7 @@ class OperationFuser:
     def _fuse_forward_ops(
         cls,
         ops: list[tuple[FusibleOperation, list[int]]],
+        recipe: Optional[Recipe],  # pylint: disable=unused-argument
     ) -> list[tuple[FusibleOperation, list[int]]]:
         """Attempt to fuse operations in forward pass"""
         ops = fuse_userbuffers_forward_linear(ops)

@@ -344,16 +379,18 @@ class OperationFuser:
     def _fuse_backward_ops(
         cls,
         ops: list[tuple[FusibleOperation, list[int]]],
+        recipe: Optional[Recipe],
     ) -> list[tuple[FusibleOperation, list[int]]]:
         """Attempt to fuse operations in backward pass"""
         ops = fuse_userbuffers_backward_linear(ops)
         ops = fuse_backward_linear_add(ops)
+        ops = fuse_backward_bias_activation(ops, recipe)
         return ops

     def fuse_ops(self) -> None:
         """Attempt to fuse operations"""
-        self._forward_ops = self._fuse_forward_ops(self._forward_ops)
-        self._backward_ops = self._fuse_backward_ops(self._backward_ops)
+        self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
+        self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)

     def __call__(
         self,

@@ -368,8 +405,10 @@ class OperationFuser:
         )

         # Initialization before forward pass
         if self._is_first_forward:
             for op in self._basic_ops:
-                op.pre_forward()
+                op.pre_first_forward(recipe=self.recipe)
             self._is_first_forward = False

         # Canonicalize op kwargs
         if basic_op_kwargs is None:
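Usage sketch (not part of the commit): OperationFuser now takes the quantization recipe explicitly instead of reading FP8 global state at fuse time. The wrapper below simply mirrors the pattern that BasicOperation.forward and FusedOperation.forward use in this diff; the operation instance `op` is assumed to already exist.

    # Minimal sketch, assuming `op` is an existing FusibleOperation and TE is installed.
    import torch
    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
    from transformer_engine.pytorch.ops.fuser import OperationFuser

    def run_op(op, x: torch.Tensor) -> torch.Tensor:
        # Resolve the active recipe once, then hand it to the fuser
        # (fuse_ops and recipe are now required constructor arguments).
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
        return OperationFuser([op], fuse_ops=False, recipe=recipe)(x, basic_op_kwargs=[{}])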
transformer_engine/pytorch/ops/op.py

@@ -65,17 +65,27 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
     def is_fused_op(self) -> bool:
         """Whether this op is the fusion of one or more basic ops"""

-    def pre_forward(self) -> None:
+    def pre_first_forward(self, *, recipe: Optional[Recipe]) -> None:
         """Preprocessing before forward pass"""

+    def get_input_quantizer(self) -> Optional[Quantizer]:
+        """Get builder class for quantized input tensor"""
+
+    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+        """Get builder class for quantized input's grad tensor"""
+
     def fuser_forward(
         self,
         basic_op_ctxs: list[OperationContext],
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
         """Forward pass

@@ -94,12 +104,10 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
             Input tensor
         basic_op_extra_inputs: list of torch.Tensor
             Extra tensor inputs to basic operations
-        basic_op_prev_ops: list of BasicOperation
-            Basic operations that preceed this operation's basic
-            operations
-        basic_op_next_ops: list of BasicOperation
-            Basic operations that follow this operation's basic
-            operations
+        prev_op_grad_input_quantizer: Quantizer, optional
+            The grad_input_quantizer of the preceeding operation
+        next_op_input_quantizer: Quantizer, optional
+            The input_quantizer of the following operation
         basic_op_kwargs: list of dict
             Keyword arguments to forward functions of basic
             operations.

@@ -201,17 +209,23 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """
         return 0

+    def get_input_quantizer(self) -> Optional[Quantizer]:
+        if self.num_quantizers("forward") > 0:
+            return self.get_quantizer("forward", 0)
+        return None
+
+    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+        if self.num_quantizers("backward") > 0:
+            return self.get_quantizer("backward", 0)
+        return None
+
     def _reset_quantization_recipe_state(
         self,
         *,
-        recipe: Optional[Recipe] = None,
+        recipe: Recipe,
     ) -> None:
         """Construct state for quantization recipe"""

-        # Quantization recipe
-        if recipe is None:
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-
         # Quantization recipe state for forward and backward pass
         self._fp8_metas = {"forward": None, "backward": None}
         self._quantizers = {"forward": [], "backward": []}

@@ -246,14 +260,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     def _update_quantization_recipe_state(
         self,
         *,
-        recipe: Optional[Recipe] = None,
+        recipe: Recipe,
     ) -> None:
         """Make sure quantizer state matches quantization recipe"""

-        # Quantization recipe
-        if recipe is None:
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-
         # Reset quantization state if needed
         if self._fp8_metas is None or self._quantizers is None:
             self._reset_quantization_recipe_state(recipe=recipe)

@@ -327,7 +337,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """
         if self._quantizers is None:
-            self._reset_quantization_recipe_state()
+            self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
         return self._quantizers[mode][index]

     @torch.no_grad()

@@ -378,19 +388,16 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
             self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

-    def pre_forward(
+    def pre_first_forward(
         self,
         *,
-        fp8_enabled: Optional[bool] = None,
-        fp8_recipe: Optional[Recipe] = None,
+        recipe: Optional[Recipe],
     ) -> None:
         """Preprocessing before forward pass"""

         # Initialize FP8 metadata if needed
-        if fp8_enabled is None:
-            fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
-        if fp8_enabled:
-            self._update_quantization_recipe_state(recipe=fp8_recipe)
+        if recipe is not None:
+            self._update_quantization_recipe_state(recipe=recipe)
             if not FP8GlobalStateManager.fp8_graph_capturing():
                 if self.num_quantizers("forward"):
                     FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(

@@ -407,8 +414,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         ctx: OperationContext,
         input_: torch.Tensor,
         *,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         **kwargs: Any,
     ) -> torch.Tensor:
         """Forward pass

@@ -419,10 +426,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             Context to coordinate between forward and backward passes
         input_: torch.Tensor
             Input tensor
-        prev_op: BasicOperation, optional
-            Basic operation that preceeds this operation
-        next_op: BasicOperation, optional
-            Basic operation that follows this operation
+        prev_op_grad_input_quantizer: Quantizer, optional
+            The grad_input_quantizer of the preceeding operation
+        next_op_input_quantizer: Quantizer, optional
+            The input_quantizer of the following operation

         Returns
         -------

@@ -461,8 +468,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, list[tuple[()]]]:
         if self.num_extra_inputs > 0 or self.num_extra_outputs > 0:

@@ -475,8 +482,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         output = self.op_forward(
             basic_op_ctxs[0],
             input_,
-            prev_op=basic_op_prev_ops[0],
-            next_op=basic_op_next_ops[0],
+            prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+            next_op_input_quantizer=next_op_input_quantizer,
             **basic_op_kwargs[0],
         )
         return output, [()]

@@ -511,7 +518,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """Apply operation"""
         from .fuser import OperationFuser

-        return OperationFuser([self], fuse_ops=False)(
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
             input,
             *extra_inputs,
             basic_op_kwargs=[kwargs],

@@ -621,7 +630,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         # Get op's quantizer state, initializing if needed
         if self._fp8_metas is None or self._fp8_metas[mode] is None:
             with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
-                self._reset_quantization_recipe_state()
+                self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
         fp8_meta = self._fp8_metas[mode]

         # Load extra items

@@ -696,10 +705,16 @@ class FusedOperation(FusibleOperation):
     def is_fused_op(self) -> bool:
         return True

-    def pre_forward(self) -> None:
+    def get_input_quantizer(self) -> Optional[Quantizer]:
+        return self.basic_ops[0].get_input_quantizer()
+
+    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+        return self.basic_ops[-1].get_grad_input_quantizer()
+
+    def pre_first_forward(self, *args, **kwargs) -> None:
         """Preprocessing before forward pass"""
         for op in self.basic_ops:
-            op.pre_forward()
+            op.pre_first_forward(*args, **kwargs)

     def forward(
         self,

@@ -712,7 +727,9 @@ class FusedOperation(FusibleOperation):
             basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]
         from .fuser import OperationFuser

-        return OperationFuser([self], fuse_ops=False)(
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
             input,
             *extra_inputs,
             basic_op_kwargs=basic_op_kwargs,
transformer_engine/pytorch/ops/sequential.py

@@ -10,6 +10,7 @@ from typing import Optional
 import torch

 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.fuser import OperationFuser

@@ -37,6 +38,9 @@ class Sequential(torch.nn.Module):
         self._module_groups: Optional[list[OperationFuser | torch.nn.Module]]
         self._module_groups = None

+        # Global state of last iteration
+        self._last_global_state = None
+
         # Add modules
         if len(args) == 1 and isinstance(args[0], dict):
             for key, module in args[0].items():

@@ -143,6 +147,7 @@ class Sequential(torch.nn.Module):
     def _make_module_groups(
         cls,
         modules: Iterable[torch.nn.Module],
+        recipe: Optional[Recipe],
     ) -> list[OperationFuser | torch.nn.Module]:
         """Make list of modules, with fusible operations grouped together"""

@@ -157,7 +162,7 @@ class Sequential(torch.nn.Module):
                 groups.append(module)
         for idx, group in enumerate(groups):
             if isinstance(group, list):
-                groups[idx] = OperationFuser(group, fuse_ops=True)
+                groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)

         # Check if operations expect extra input or output tensors
         # Note: If any op has extra inputs or outputs, then the entire

@@ -185,9 +190,19 @@ class Sequential(torch.nn.Module):
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass"""

+        # Get current global state
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+        global_state = (with_quantized_compute, type(recipe))
+
+        # Reset module groups is global state changed
+        if self._last_global_state != global_state:
+            self._module_groups = None
+        self._last_global_state = global_state
+
         # Create module groups if needed
         if self._module_groups is None:
-            self._module_groups = self._make_module_groups(self._modules.values())
+            self._module_groups = self._make_module_groups(self._modules.values(), recipe)

         # Forward pass for each module group
         x = input
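For context (not from this commit): the cache key Sequential now tracks is the pair (FP8 enabled, recipe type), so switching recipe types between iterations invalidates the cached OperationFuser groups and they are rebuilt with the new recipe. A standalone sketch of that key, assuming only that FP8GlobalStateManager behaves as used in this diff:

    # Minimal sketch of the cache-invalidation rule added to Sequential.forward.
    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    def sequential_cache_key():
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
        # (enabled flag, recipe *type*) is the key: e.g. moving from delayed scaling to
        # MXFP8 block scaling changes the tuple and forces the module groups to be rebuilt.
        return (with_quantized_compute, type(recipe))

    print(sequential_cache_key())  # outside fp8_autocast: (False, <class 'NoneType'>)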
transformer_engine/pytorch/permutation.py

@@ -349,7 +349,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
         if restore_shape is None:
             restore_shape = inp.shape
         num_tokens, hidden_size = restore_shape
-        num_experts = row_id_map.size(0)
+        num_experts = (row_id_map.size(1) - 1) // 2
         with_probs = merging_probs is not None
         if with_probs:

@@ -651,14 +651,20 @@ class _moe_chunk_sort(torch.autograd.Function):
             fp8_scale_inv = inp._scale_inv
             fake_dtype = inp.dtype
             inp = inp._data
-        output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
-            inp,
+        row_id_map = triton_permutation.make_chunk_sort_map(
             split_sizes,
             sorted_idxs,
             num_tokens,
             num_splits,
         )
+        output, permuted_probs = triton_permutation.sort_chunks_by_map(
+            inp,
+            row_id_map,
+            probs,
+            num_tokens,
+            hidden_size,
+            num_splits,
+            is_forward=True,
+        )
         if fp8:
             output = Float8Tensor(

@@ -700,6 +706,7 @@ class _moe_chunk_sort(torch.autograd.Function):
             permuted_probs_grad,
             ctx.num_tokens,
             ctx.hidden_size,
+            is_forward=False,
         )
         if fp8:
             act_grad = Float8Tensor(
transformer_engine/pytorch/router.py (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Fused functions used in the MoE router
"""
import torch

import transformer_engine_torch as tex


class FusedTopkScoreFunction(torch.autograd.Function):
    """
    Fused Topk with Score Function router.
    Currently, only support softmax and sigmoid.
    """

    @staticmethod
    def forward(
        ctx,
        logits: torch.Tensor,
        topk: int,
        use_pre_softmax: bool,
        num_groups: int,
        group_topk: int,
        scaling_factor: float,
        score_function: str,
        expert_bias: torch.Tensor,
    ):
        # pylint: disable=missing-function-docstring
        # Save the shape of the logits
        tensor_shape = logits.shape
        logits = logits.view(-1, tensor_shape[-1])
        # Get the metadata of the viewed logits
        num_tokens = logits.size(0)
        num_experts = logits.size(1)
        probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
            logits,
            topk,
            use_pre_softmax,
            num_groups,
            group_topk,
            scaling_factor,
            score_function,
            expert_bias,
        )
        # Restore the shape
        probs = probs.view(tensor_shape)

        ctx.save_for_backward(routing_map, intermediate_output)
        ctx.num_tokens = num_tokens
        ctx.num_experts = num_experts
        ctx.use_pre_softmax = use_pre_softmax
        ctx.topk = topk
        ctx.scaling_factor = scaling_factor
        ctx.score_function = score_function
        return probs, routing_map

    @staticmethod
    def backward(ctx, grad_probs, _):
        # pylint: disable=missing-function-docstring
        routing_map, intermediate_output = ctx.saved_tensors
        # Save the shape of the grad_probs
        tensor_shape = grad_probs.shape
        # Adjust the shape of the grad_probs to 2D shape
        grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1])
        grad_logits = tex.fused_topk_with_score_function_bwd(
            ctx.num_tokens,
            ctx.num_experts,
            routing_map,
            intermediate_output,
            grad_probs,
            ctx.topk,
            ctx.use_pre_softmax,
            ctx.scaling_factor,
            ctx.score_function,
        )
        # Restore the shape
        grad_logits = grad_logits.view(tensor_shape)
        return grad_logits, None, None, None, None, None, None, None


def fused_topk_with_score_function(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool,
    num_groups: int,
    group_topk: int,
    scaling_factor: float,
    score_function: str,
    expert_bias: torch.Tensor,
):
    """
    Fused topk with score function router.

    Parameters
    ----------
    logits: torch.Tensor
    topk: int
    use_pre_softmax: bool
        if enabled, the computation order: softmax -> topk
    num_groups: int
        used in the group topk
    group_topk: int
        used in the group topk
    scaling_factor: float
    score_function: str
        currently only support softmax and sigmoid
    expert_bias: torch.Tensor
        could be used in the sigmoid

    Returns
    -------
    probs: torch.Tensor
    routing_map: torch.Tensor
    """
    if logits.dtype == torch.float64:
        raise ValueError("Current TE does not support float64 router type")
    return FusedTopkScoreFunction.apply(
        logits,
        topk,
        use_pre_softmax,
        num_groups,
        group_topk,
        scaling_factor,
        score_function,
        expert_bias,
    )


class FusedComputeScoresForMoEAuxLoss(torch.autograd.Function):
    """
    Fused compute scores for MoE aux loss.
    """

    @staticmethod
    def forward(
        ctx,
        logits: torch.Tensor,
        topk: int,
        score_function: str,
    ):
        # pylint: disable=missing-function-docstring
        # Save the shape of the logits
        tensor_shape = logits.shape
        logits = logits.view(-1, tensor_shape[-1])
        # Get the metadata of the viewed logits
        num_tokens = logits.size(0)
        num_experts = logits.size(1)
        scores, routing_map, intermediate_output = tex.fused_score_for_moe_aux_loss_fwd(
            logits=logits,
            topk=topk,
            score_function=score_function,
        )

        ctx.save_for_backward(intermediate_output)
        ctx.topk = topk
        ctx.score_function = score_function
        ctx.num_tokens = num_tokens
        ctx.num_experts = num_experts
        return routing_map, scores

    @staticmethod
    def backward(ctx, _, grad_scores):
        # pylint: disable=missing-function-docstring
        intermediate_output = ctx.saved_tensors[0]
        # Save the shape of the grad_scores
        tensor_shape = grad_scores.shape
        # Adjust the shape of the grad_scores to 2D shape
        grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1])
        grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
            num_tokens=ctx.num_tokens,
            num_experts=ctx.num_experts,
            intermediate_output=intermediate_output,
            grad_scores=grad_scores,
            topk=ctx.topk,
            score_function=ctx.score_function,
        )
        # Restore the shape
        grad_logits = grad_logits.view(tensor_shape)
        return grad_logits, None, None


def fused_compute_score_for_moe_aux_loss(
    logits: torch.Tensor,
    topk: int,
    score_function: str,
):
    """
    Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.

    Parameters
    ----------
    logits: torch.Tensor
    topk: int
    score_function: str
        currently only support softmax and sigmoid

    Returns
    -------
    routing_map: torch.Tensor
    scores: torch.Tensor
    """
    return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)


class FusedAuxLoss(torch.autograd.Function):
    """
    Fused MoE aux loss.
    """

    @staticmethod
    def forward(
        ctx,
        probs: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        total_num_tokens: int,
        num_experts: int,
        topk: int,
        coeff: float,
    ):
        # pylint: disable=missing-function-docstring
        num_rows = probs.size(0)
        num_cols = probs.size(1)
        aux_loss, Const_buf = tex.fused_moe_aux_loss_fwd(
            probs=probs,
            tokens_per_expert=tokens_per_expert,
            total_num_tokens=total_num_tokens,
            num_experts=num_experts,
            num_rows=num_rows,
            num_cols=num_cols,
            topk=topk,
            coeff=coeff,
        )
        ctx.save_for_backward(Const_buf, tokens_per_expert)
        ctx.num_rows = num_rows
        ctx.num_cols = num_cols
        return aux_loss

    @staticmethod
    def backward(ctx, grad_aux_loss):
        # pylint: disable=missing-function-docstring
        Const_buf, tokens_per_expert = ctx.saved_tensors
        grad_probs = tex.fused_moe_aux_loss_bwd(
            Const_buf=Const_buf,
            tokens_per_expert=tokens_per_expert,
            num_rows=ctx.num_rows,
            num_cols=ctx.num_cols,
            grad_aux_loss=grad_aux_loss,
        )
        return grad_probs, None, None, None, None, None


def fused_moe_aux_loss(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    num_experts: int,
    topk: int,
    coeff: float,
):
    """
    Fused MoE aux loss.

    Parameters
    ----------
    probs: torch.Tensor
    tokens_per_expert: torch.Tensor
        the number of tokens per expert
    total_num_tokens: int
        the total number of tokens, involved in the aux loss calculation
    num_experts: int
    topk: int
    coeff: float
        the coefficient of the aux loss

    Returns
    -------
    aux_loss: torch.scalar
    """
    return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)
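A hedged usage sketch for the new fused router helpers (signatures taken from the file above; requires a CUDA build of Transformer Engine with the transformer_engine_torch extension). The parameter values, the tokens_per_expert construction, and passing expert_bias=None for the softmax path are assumptions for illustration only.

    import torch
    from transformer_engine.pytorch.router import (
        fused_topk_with_score_function,
        fused_moe_aux_loss,
    )

    num_tokens, num_experts, topk = 128, 8, 2
    logits = torch.randn(num_tokens, num_experts, device="cuda", requires_grad=True)

    # Top-k routing with a softmax score function (softmax applied before top-k).
    probs, routing_map = fused_topk_with_score_function(
        logits,
        topk=topk,
        use_pre_softmax=True,
        num_groups=1,            # assumed value: no group-limited routing
        group_topk=1,            # assumed value
        scaling_factor=1.0,
        score_function="softmax",
        expert_bias=None,        # assumption: bias only used by the sigmoid path
    )

    # Load-balancing auxiliary loss from the routing probabilities.
    tokens_per_expert = routing_map.sum(dim=0)
    aux_loss = fused_moe_aux_loss(
        probs,
        tokens_per_expert,
        total_num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        coeff=1e-2,
    )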
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

@@ -43,7 +43,6 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     def __new__(
         cls,
-        *args,
         rowwise_data: Optional[torch.Tensor],
         rowwise_scale_inv: Optional[torch.Tensor],
         columnwise_data: Optional[torch.Tensor],

@@ -51,9 +50,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         fp8_dtype: TE_DType,
         quantizer: Quantizer,
         is_2D_scaled: bool,
-        data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
+        data_format: Float8BlockScaleTensorFormat,
+        *args,
         **kwargs,
     ):
+        if cls is Float8BlockwiseQTensorBase:
+            instance = object.__new__(cls)
+        else:
+            instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py

@@ -143,6 +143,23 @@ class Float8TensorBase(QuantizedTensorBase):
             size = self._transpose.size(*args, **kwargs)
             return torch.Size([size[-1], math.prod(size[:-1])])

+    def view(self, shape: torch.Size):
+        # pylint: disable=missing-function-docstring
+        out_data = self._data.view(shape)
+        out_transpose = None if self._transpose_invalid else self._transpose
+        if out_transpose is not None:
+            out_transpose_shape = out_transpose.size()
+            if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
+                out_transpose = None
+        return Float8TensorBase(
+            data=out_data,
+            fp8_scale_inv=self._scale_inv,
+            fp8_dtype=self._fp8_dtype,
+            data_transpose=out_transpose,
+            quantizer=self._quantizer,
+        )
+
     def __repr__(self):
         return (
             "Float8TensorBase("
transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py

@@ -6,6 +6,8 @@
 from __future__ import annotations
 from typing import Optional, Dict, Any, Tuple
+from collections.abc import Iterable
+import math

 import torch

 import transformer_engine_torch as tex

@@ -66,15 +68,18 @@ class MXFP8TensorBase(QuantizedTensorBase):
     def __new__(
         cls,
-        *args,
         rowwise_data: Optional[torch.Tensor],
-        rowwise_scale_inv: torch.Tensor,
+        rowwise_scale_inv: Optional[torch.Tensor],
         columnwise_data: Optional[torch.Tensor],
-        columnwise_scale_inv: torch.Tensor,
+        columnwise_scale_inv: Optional[torch.Tensor],
         fp8_dtype: TE_DType,
-        quantizer: Optional[Quantizer] = None,
+        quantizer: Optional[Quantizer],
+        *args,
         **kwargs,
     ):
-        instance = super().__new__(cls, *args, **kwargs)
+        if cls is MXFP8TensorBase:
+            instance = object.__new__(cls)
+        else:
+            instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data

@@ -145,6 +150,51 @@ class MXFP8TensorBase(QuantizedTensorBase):
             return self._rowwise_data.size(*args, **kwargs)
         return self._columnwise_data.size(*args, **kwargs)

+    def view(self, shape: torch.Size):
+        # pylint: disable=missing-function-docstring
+
+        # Return input tensor if view not needed
+        cur_shape = self.size()
+        if shape is None or shape == cur_shape:
+            return self
+
+        # Canonicalize shape
+        if not isinstance(shape, Iterable):
+            shape = [shape]
+        elif len(shape) == 1 and isinstance(shape[0], Iterable):
+            shape = shape[0]
+        if -1 in shape:
+            shape = list(shape)
+            d_inferred = -math.prod(cur_shape) // math.prod(shape)
+            for i, d in enumerate(shape):
+                if d == -1:
+                    shape[i] = d_inferred
+                    break
+        if shape[-1] != cur_shape[-1]:
+            raise RuntimeError(
+                "MXFP8Tensor does not support reshaping inner dimension "
+                f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
+            )
+
+        # Construct new tensor
+        cur_rowwise_data = self._rowwise_data
+        cur_columnwise_data = self._columnwise_data
+        new_rowwise_data = None
+        new_columnwise_data = None
+        if cur_rowwise_data is not None:
+            new_rowwise_data = cur_rowwise_data.view(*shape)
+        if cur_columnwise_data is not None:
+            new_columnwise_data = cur_columnwise_data.view(*shape)
+        return MXFP8TensorBase(
+            rowwise_data=new_rowwise_data,
+            rowwise_scale_inv=self._rowwise_scale_inv,
+            columnwise_data=new_columnwise_data,
+            columnwise_scale_inv=self._columnwise_scale_inv,
+            fp8_dtype=self._fp8_dtype,
+            quantizer=self._quantizer,
+        )
+
     def __repr__(self):
         data_rowwise = self.dequantize()
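The new MXFP8TensorBase.view only permits reshapes that keep the innermost dimension (the quantization block axis) intact. A small standalone sketch of the same shape canonicalization, in plain Python so it can be checked without quantized data (the helper name is hypothetical, not TE API):

    import math

    def canonicalize_mxfp8_view_shape(cur_shape, shape):
        """Re-implements the -1 inference and inner-dim check from MXFP8TensorBase.view."""
        shape = list(shape)
        if -1 in shape:
            # Infer the -1 dimension from the remaining element count, as in the diff.
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            shape[shape.index(-1)] = d_inferred
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError("MXFP8Tensor does not support reshaping inner dimension")
        return tuple(shape)

    print(canonicalize_mxfp8_view_shape((4, 32, 128), (-1, 128)))  # (128, 128): allowed
    # canonicalize_mxfp8_view_shape((4, 32, 128), (4, 64, 64))     # raises: inner dim changes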
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

@@ -11,6 +11,7 @@ import torch
 import transformer_engine_torch as tex
 import os
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.common.recipe import Float8BlockScaling, Recipe

 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase

@@ -297,6 +298,37 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         holds configuration about quantization and dequantization modes.
     """

+    # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
+    # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
+    def __new__(
+        cls,
+        *args,
+        rowwise_data: Optional[torch.Tensor],
+        rowwise_scale_inv: Optional[torch.Tensor],
+        columnwise_data: Optional[torch.Tensor],
+        columnwise_scale_inv: Optional[torch.Tensor],
+        fp8_dtype: TE_DType,
+        quantizer: Quantizer,
+        is_2D_scaled: bool,
+        data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
+        **kwargs,
+    ):
+        instance = super().__new__(
+            cls,
+            rowwise_data,
+            rowwise_scale_inv,
+            columnwise_data,
+            columnwise_scale_inv,
+            fp8_dtype,
+            quantizer,
+            is_2D_scaled,
+            data_format,
+            *args,
+            **kwargs,
+        )
+        return instance
+
     def __repr__(self, *, tensor_contents=None):
         return (
             f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
transformer_engine/pytorch/tensor/float8_tensor.py

@@ -167,6 +167,21 @@ class Float8Quantizer(Quantizer):
             quantizer=self,
         )

+    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
+        """Function using primitives with ONNX defined translations."""
+        # Q inputs are currently constrained to FP32 due to a similar limitation in ORT
+        # custom ops, so cast the input if needed.
+        if tensor.dtype != torch.float32:
+            tensor = tensor.to(torch.float32)
+        data = torch.ops.tex.fp8_quantize(tensor, self.scale.item())
+        return self.create_tensor_from_data(data, fake_dtype=torch.float32)
+
+    def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
+        """Function using primitives with ONNX defined translations."""
+        out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
+        out = out.to(tensor.dtype)
+        return out
+
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         return DelayedScaling

@@ -328,6 +343,18 @@ class Float8CurrentScalingQuantizer(Quantizer):
             quantizer=self,
         )

+    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
+        """Function using primitives with ONNX defined translations."""
+        raise NotImplementedError(
+            "Float8CurrentScalingQuantizer does not support ONNX quantization yet."
+        )
+
+    def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
+        """Function using primitives with ONNX defined translations."""
+        raise NotImplementedError(
+            "Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
+        )
+
     def _canonicalized_amax_reduction_group(self) -> dist_group_type:
         """Get process group for amax reduction"""
         return canonicalize_process_group(self.amax_reduction_group)
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @
44740c6c
...
...
@@ -136,6 +136,34 @@ class MXFP8Quantizer(Quantizer):
        # TODO(ksivamani): No calibration needed for mxfp8?
        pass

    def create_tensor_from_data(
        self,
        data: torch.Tensor,
        scale_inv: torch.Tensor,
        fake_dtype: torch.dtype,
        fp8_dtype: TE_DType = tex.DType.kFloat8E4M3,
    ) -> MXFP8Tensor:
        """Create a new MXFP8Tensor from data and scale_inv."""
        return MXFP8Tensor(
            shape=data.shape,
            dtype=fake_dtype,
            rowwise_data=data,
            rowwise_scale_inv=scale_inv,
            columnwise_data=None,
            columnwise_scale_inv=None,
            fp8_dtype=fp8_dtype,
            quantizer=self,
        )

    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        if tensor.dtype != torch.float32:
            tensor = tensor.to(dtype=torch.float32)
        data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
        return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)

    def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor:
        return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        return MXFP8BlockScaling
...
...
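A hedged sketch of how `create_tensor_from_data` might be used to wrap pre-quantized buffers, assuming a constructed MXFP8Quantizer `q` and raw `data`/`scale_inv` tensors produced elsewhere (for example by `torch.ops.tex.mxfp8_quantize`):

import torch

def wrap_mxfp8_buffers(q, data: torch.Tensor, scale_inv: torch.Tensor):
    # Builds a rowwise-only MXFP8Tensor; columnwise buffers are left unset (None),
    # and the "fake" high-precision dtype seen by PyTorch is float32 here.
    return q.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)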
@@ -165,6 +193,32 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
"""
# NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def
__new__
(
cls
,
*
args
,
rowwise_data
:
Optional
[
torch
.
Tensor
],
rowwise_scale_inv
:
Optional
[
torch
.
Tensor
],
columnwise_data
:
Optional
[
torch
.
Tensor
],
columnwise_scale_inv
:
Optional
[
torch
.
Tensor
],
fp8_dtype
:
TE_DType
,
quantizer
:
Optional
[
Quantizer
],
**
kwargs
,
):
instance
=
super
().
__new__
(
cls
,
rowwise_data
,
rowwise_scale_inv
,
columnwise_data
,
columnwise_scale_inv
,
fp8_dtype
,
quantizer
,
*
args
,
**
kwargs
,
)
return
instance
def
__repr__
(
self
,
*
,
tensor_contents
=
None
):
return
f
"MXFP8Tensor(fp8_dtype=
{
self
.
_fp8_dtype
}
, data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
)"
...
...
@@ -302,6 +356,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
        fp8_dtype: TE_DType,
        dtype: torch.dtype,
        shape: torch.shape,
        quantizer: Optional[Quantizer] = None,
    ) -> MXFP8Tensor:
        """Build MXFP8Tensor, for use in __reduce__
...
...
@@ -317,6 +372,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
            columnwise_scale_inv=columnwise_scale_inv,
            dtype=dtype,
            shape=shape,
            quantizer=quantizer,
        )

    def __reduce_ex__(self, protocol: int) -> tuple:
...
...
@@ -331,6 +387,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
                self._fp8_dtype,
                self.dtype,
                self.shape,
                self._quantizer,
            ),
        )
...
...
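The hunks above add `quantizer` to the `_make`-style rebuild helper and to the state tuple returned by `__reduce_ex__`, so pickling now preserves the quantizer. A minimal sketch, assuming `t` is an existing MXFP8Tensor:

import pickle

def quantizer_survives_pickle(t):
    # With the extra self._quantizer entry in __reduce_ex__, the rebuilt tensor
    # is expected to carry the same quantizer configuration as the original.
    t2 = pickle.loads(pickle.dumps(t))
    return t2._quantizer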
@@ -437,8 +494,7 @@ class _ViewFunc(torch.autograd.Function):
        if tensor._rowwise_data is not None:
            new_rowwise_data = tensor._rowwise_data.view(*shape)
        if tensor._columnwise_data is not None:
-            columnwise_shape = [shape[-1]] + list(shape[:-1])
-            new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
+            new_columnwise_data = tensor._columnwise_data.view(*shape)
        return MXFP8Tensor(
            shape,
            tensor.dtype,
...
...
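For reference, the shape arithmetic that changes in these view/reshape hunks: the removed lines viewed the columnwise buffer with the last dimension moved to the front, while the new line views it with the same shape as the rowwise buffer. A small standalone illustration (plain lists, not the real tensors):

shape = (4, 8, 16)

# Old layout used by the removed lines: last dim first, remaining dims after it.
old_columnwise_shape = [shape[-1]] + list(shape[:-1])   # [16, 4, 8]

# New behaviour: the columnwise buffer is viewed with the same shape as rowwise.
new_columnwise_shape = list(shape)                      # [4, 8, 16]

print(old_columnwise_shape, new_columnwise_shape)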
@@ -462,7 +518,7 @@ class _ViewFunc(torch.autograd.Function):
            grad._rowwise_data.view(*ctx.shape)
            if grad._rowwise_data is not None
            else None
        )
        if grad._columnwise_data is not None:
-            new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1)
+            new_columnwise_data = grad._columnwise_data.view(*ctx.shape)
        else:
            new_columnwise_data = None
        dgrad = MXFP8Tensor(
...
...
@@ -523,8 +579,7 @@ class _ReshapeFunc(torch.autograd.Function):
        if tensor._rowwise_data is not None:
            new_rowwise_data = tensor._rowwise_data.reshape(*shape)
        if tensor._columnwise_data is not None:
-            columnwise_shape = [shape[-1]] + list(shape[:-1])
-            new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
+            new_columnwise_data = tensor._columnwise_data.view(*shape)
        return MXFP8Tensor(
            shape,
...
...
@@ -550,8 +605,7 @@ class _ReshapeFunc(torch.autograd.Function):
        if grad._rowwise_data is not None:
            new_rowwise_data = grad._rowwise_data.view(*ctx.shape)
        if grad._columnwise_data is not None:
-            columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
-            new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
+            new_columnwise_data = grad._columnwise_data.view(*ctx.shape)
        dgrad = MXFP8Tensor(
            ctx.shape,
            grad.dtype,
...
...
transformer_engine/pytorch/tensor/quantized_tensor.py  View file @ 44740c6c
...
...
@@ -250,6 +250,12 @@ class Quantizer(abc.ABC):
"""Create shallow copy"""
return
copy
.
copy
(
self
)
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Symbolic function for ONNX export"""
def
onnx_dequantize
(
self
,
tensor
)
->
torch
.
Tensor
:
"""Symbolic function for ONNX export"""
@
abc
.
abstractmethod
def
_get_compatible_recipe
(
self
)
->
Union
[
type
[
Recipe
],
None
]:
"""Returns recipe class that is compatible with this quantizer"""
...
...
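The base `Quantizer` now declares `onnx_quantize`/`onnx_dequantize` hooks (no-op by default) alongside the abstract `_get_compatible_recipe`. A toy, standalone illustration of the hook contract these methods are meant to satisfy: quantize using only ops with ONNX translations, and dequantize back to a plain tensor. This is not a real TE quantizer.

import torch

class ToyInt8Quantizer:
    def __init__(self, scale: float):
        self.scale = scale

    def onnx_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
        # Round-to-nearest affine quantization built from ONNX-friendly ops.
        return torch.clamp(torch.round(tensor / self.scale), -128, 127).to(torch.int8)

    def onnx_dequantize(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(torch.float32) * self.scale

q = ToyInt8Quantizer(scale=0.1)
x = torch.randn(4).clamp(-1, 1)
assert torch.allclose(q.onnx_dequantize(q.onnx_quantize(x)), x, atol=0.1)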
transformer_engine/pytorch/tensor/utils.py  View file @ 44740c6c
...
...
@@ -194,7 +194,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
        quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8)

    if len(amaxes) > 0:
-        dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device)
+        dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=amaxes[0].device)

        # Reduce amaxes.
        packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
...
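The replaced line swaps `torch.tensor([0], ...)` for `torch.zeros(1, ...)`. Both produce the same one-element buffer, but `torch.tensor` builds the value from Python data on the host first, whereas `torch.zeros` allocates and zero-fills directly on the target device, which is presumably the motivation here. A quick equivalence check:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Old form: constructs the value from Python data on the host, then moves it.
a = torch.tensor([0], dtype=torch.int, device=device)
# New form: allocates and zero-fills the buffer directly on the target device.
b = torch.zeros(1, dtype=torch.int, device=device)

assert torch.equal(a, b) and a.shape == b.shape == (1,)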
transformer_engine/pytorch/transformer.py  View file @ 44740c6c
...
...
@@ -33,6 +33,7 @@ from transformer_engine.pytorch.constants import (
    dist_group_type,
)
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
...
...
@@ -814,7 +815,12 @@ class TransformerLayer(torch.nn.Module):
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
-        if drop_path is None and bias is not None and bias.numel() != 0:
+        if (
+            drop_path is None
+            and bias is not None
+            and bias.numel() != 0
+            and not is_in_onnx_export_mode()
+        ):
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
...
...
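The change above routes around the fused bias-dropout-add path when exporting to ONNX. A standalone sketch of the same guard pattern, assuming Transformer Engine is installed for the `is_in_onnx_export_mode` helper; the reference path is a stand-in composition, not the library's fused kernels:

import torch
import torch.nn.functional as F
from transformer_engine.pytorch.export import is_in_onnx_export_mode

def bias_dropout_add_reference(hidden_state, bias, residual, p=0.1, training=True):
    # Unfused fallback: simple composition of ONNX-friendly ops.
    return residual + F.dropout(hidden_state + bias, p=p, training=training)

def maybe_fused(hidden_state, bias, residual, fused_fn):
    # Prefer the fused kernel only when it is safe to do so; fall back to the
    # reference path during ONNX export, mirroring the condition in the diff.
    if bias is not None and bias.numel() != 0 and not is_in_onnx_export_mode():
        return fused_fn(hidden_state, bias, residual)
    return bias_dropout_add_reference(
        hidden_state, bias if bias is not None else 0.0, residual
    )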
transformer_engine/pytorch/triton/cross_entropy.py  View file @ 44740c6c
...
...
@@ -98,6 +98,7 @@ def cross_entropy_kernel(
    ignore_idx,
    n_cols,
    n_non_ignore,
    reduce_loss: tl.constexpr,
    label_smoothing: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
...
...
@@ -177,7 +178,13 @@ def cross_entropy_kernel(
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))

        # Scale gradients based on reduction mode
        # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
        # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
        if reduce_loss:
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps
        tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written
...
...
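The comments above explain why the stored gradient is divided by `n_non_ignore` only when the loss is mean-reduced; with `reduce_loss=False` PyTorch applies no further scaling. A plain-PyTorch reference of the same two cases, ignoring label smoothing and ignored tokens for brevity (the Triton kernel is not reproduced here):

import torch
import torch.nn.functional as F

def softmax_xent_grad(logits, target, reduce_loss):
    # d(loss)/d(logits) for softmax cross entropy: softmax(logits) - one_hot(target),
    # divided by the number of rows only under mean reduction.
    grad = F.softmax(logits, dim=-1)
    grad[torch.arange(logits.size(0)), target] -= 1.0
    if reduce_loss:
        grad = grad / logits.size(0)
    return grad

logits = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([0, 2, 4])
F.cross_entropy(logits, target, reduction="mean").backward()
assert torch.allclose(logits.grad, softmax_xent_grad(logits.detach(), target, True), atol=1e-5)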
@@ -205,7 +212,11 @@ def cross_entropy_kernel(
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            X_y = tl.load(X_ptr + y - vocab_start_idx)
            # Apply the same conditional scaling logic for the target token
            if reduce_loss:
                X_y += -(1 - label_smoothing) / (n_non_ignore)
            else:
                X_y += -(1 - label_smoothing)
            tl.store(X_ptr + y - vocab_start_idx, X_y)

    tl.store(loss_ptr, loss)
...
...
@@ -319,6 +330,7 @@ def cross_entropy_forward(
        ignore_idx=ignore_idx,
        n_cols=V,
        n_non_ignore=n_rows,
        reduce_loss=reduce_loss,
        label_smoothing=label_smoothing,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=16 if IS_HIP_EXTENSION else 32,
...