OpenDAS / TransformerEngine · Commits

Commit 2b05e121
Authored Jun 17, 2025 by yuguo

    Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac
Changes: 245
Showing 20 changed files with 1003 additions and 353 deletions (+1003, -353)
transformer_engine/pytorch/jit.py  +91 -0
transformer_engine/pytorch/module/base.py  +124 -17
transformer_engine/pytorch/module/grouped_linear.py  +36 -24
transformer_engine/pytorch/module/layernorm_linear.py  +81 -50
transformer_engine/pytorch/module/layernorm_mlp.py  +74 -47
transformer_engine/pytorch/module/linear.py  +72 -43
transformer_engine/pytorch/ops/basic/__init__.py  +1 -0
transformer_engine/pytorch/ops/basic/activation.py  +22 -10
transformer_engine/pytorch/ops/basic/basic_linear.py  +73 -27
transformer_engine/pytorch/ops/basic/l2normalization.py  +128 -0
transformer_engine/pytorch/ops/fused/backward_linear_add.py  +2 -2
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py  +10 -4
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py  +10 -4
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py  +21 -11
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py  +51 -23
transformer_engine/pytorch/ops/fuser.py  +33 -51
transformer_engine/pytorch/pyproject.toml  +10 -0
transformer_engine/pytorch/setup.py  +3 -13
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py  +77 -18
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py  +84 -9
transformer_engine/pytorch/jit.py  (View file @ 2b05e121)

@@ -123,6 +123,35 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
     return dgelu


+@jit_fuser
+def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor:
+    """L2 normalization fused - inference version"""
+    x_squared = x.pow(2)
+    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
+    rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
+    return x * rsqrt_norm
+
+
+@jit_fuser
+def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
+    """L2 normalization fused - training version that returns intermediate values"""
+    x_squared = x.pow(2)
+    l2_norm_squared = x_squared.sum(dim=-1, keepdim=True)
+    rsqrt_norm = torch.rsqrt(l2_norm_squared + eps)
+    y = x * rsqrt_norm
+    return y, rsqrt_norm
+
+
+@jit_fuser
+def l2normalization_backward_fused_(
+    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
+) -> torch.Tensor:
+    """L2 normalization backward fused"""
+    x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True)
+    x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps
+    return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared)
+
+
 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
     with gpu_autocast_ctx(enabled=False):

@@ -141,6 +170,26 @@ def bgrad_dgelu_fused(
     return None, dgelu_fused_(grad_output, inp)


+def l2normalization_fused(x: torch.Tensor, eps: float) -> torch.Tensor:
+    """Disable native AMP for l2normalization_fused_ - inference version"""
+    with gpu_autocast_ctx(enabled=False):
+        return l2normalization_fused_(x, eps)
+
+
+def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
+    """Disable native AMP for l2normalization_fwd_fused_ - training version"""
+    with gpu_autocast_ctx(enabled=False):
+        return l2normalization_fwd_fused_(x, eps)
+
+
+def l2normalization_backward_fused(
+    grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float
+) -> torch.Tensor:
+    """Disable native AMP for l2normalization_backward_fused_"""
+    with gpu_autocast_ctx(enabled=False):
+        return l2normalization_backward_fused_(grad_output, x, rsqrt_norm, eps)
+
+
 def bias_dropout_add(
     x: torch.Tensor,
     bias: torch.Tensor,

@@ -266,3 +315,45 @@ def warmup_jit_bias_gelu_all_dtypes(
     """Call `warmup_jit_bias_gelu` for all training dtypes"""
     for dtype in [torch.float32, torch.bfloat16, torch.float16]:
         warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size)
+
+
+def warmup_jit_l2normalization(
+    hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
+) -> None:
+    """Compile L2Normalization JIT function before the main training steps"""
+    # Save cuda RNG state to ensure warmup does not affect reproducibility.
+    rng_state = torch.cuda.get_rng_state()
+
+    inp = torch.rand(
+        (seq_length * micro_batch_size, hidden_size),
+        dtype=dtype,
+        device="cuda",
+    )
+    eps = 1e-6
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for input_grad in [False, True]:
+        inp.requires_grad = input_grad
+        for _ in range(5):
+            if input_grad:
+                # Test training version that returns intermediate values
+                output, rsqrt_norm = l2normalization_fwd_fused_(inp, eps)
+                # Test backward pass as well
+                grad_out = torch.rand_like(output)
+                _ = l2normalization_backward_fused_(grad_out, inp, rsqrt_norm, eps)
+            else:
+                # Test inference version
+                output = l2normalization_fused_(inp, eps)
+    del inp, output
+    torch.cuda.empty_cache()
+
+    torch.cuda.set_rng_state(rng_state)
+
+
+def warmup_jit_l2normalization_all_dtypes(
+    hidden_size: int, seq_length: int, micro_batch_size: int
+) -> None:
+    """Call `warmup_jit_l2normalization` for all training dtypes"""
+    for dtype in [torch.float32, torch.bfloat16, torch.float16]:
+        warmup_jit_l2normalization(hidden_size, dtype, seq_length, micro_batch_size)
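
The three fused helpers added above implement plain L2 normalization, y = x * rsqrt(sum(x^2, dim=-1) + eps), plus its analytic gradient. Below is a minimal, self-contained sketch (not part of the commit) that re-derives the backward formula in eager PyTorch on CPU and checks it against autograd; it assumes nothing beyond stock torch and uses float64 only to keep the comparison tight.

import torch

def l2norm_ref(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Same formula as l2normalization_fused_: y = x * rsqrt(sum(x^2, dim=-1) + eps)
    return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)

eps = 1e-6
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(4, 8, dtype=torch.float64)

# Reference gradient from autograd
(dx_autograd,) = torch.autograd.grad(l2norm_ref(x, eps), x, grad_out)

# Manual gradient, mirroring l2normalization_backward_fused_
xd = x.detach()
rsqrt_norm = torch.rsqrt(xd.pow(2).sum(dim=-1, keepdim=True) + eps)
x_dy_sum = (xd * grad_out).sum(dim=-1, keepdim=True)
x_norm_squared = xd.pow(2).sum(dim=-1, keepdim=True) + eps
dx_manual = rsqrt_norm * (grad_out - xd * x_dy_sum / x_norm_squared)

print(torch.allclose(dx_autograd, dx_manual))  # expected: True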
transformer_engine/pytorch/module/base.py  (View file @ 2b05e121)

@@ -44,7 +44,7 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..utils import torch_get_autocast_gpu_dtype
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
-from ...common.recipe import Recipe
+from ...common.recipe import DelayedScaling, Recipe
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

@@ -89,7 +89,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
     """Returns workspace for multi-stream cublas."""
     global _multi_stream_cublas_workspace
     if not _multi_stream_cublas_workspace:
-        for _ in range(tex._num_cublas_streams):
+        for _ in range(tex.get_num_cublas_streams()):
            _multi_stream_cublas_workspace.append(
                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
            )

@@ -685,6 +685,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()
+           # Make sure weight tensors has correct quantizers
+           self._update_weight_quantizers()

        # Update the global buffers with new amax and history pointers.
        if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:

@@ -738,6 +740,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        self.fp8_meta[fp8_meta_tensor_key] = recipe_state
        self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

+    def _update_weight_quantizers(self) -> None:
+        """Update the quantizers for the weight tensors."""
+        weight_tensors = self._get_weight_tensors()
+        weight_quantizers = self._get_weight_quantizers()
+        assert len(weight_tensors) == len(weight_quantizers), (
+            f"Number of weight tensors ({len(weight_tensors)}) and quantizers "
+            f"({len(weight_quantizers)}) must match"
+        )
+        for weight, quantizer in zip(weight_tensors, weight_quantizers):
+            if quantizer is not None and isinstance(weight, QuantizedTensorBase):
+                weight.update_quantizer(quantizer)
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
+        )
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement _get_weight_quantizers function"
+        )
+
     def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Init scales and amaxes."""
        self.set_meta_tensor(True, recipe)

@@ -777,7 +803,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        reset("scaling_fwd")
        reset("scaling_bwd")

-    def get_extra_state(self) -> Optional[torch.Tensor]:
+    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing."""

        # This implementation is working around a few issues:

@@ -812,7 +838,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        state = None

        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        if not fp8_checkpoint:
-           return None
+           return torch.empty(0, dtype=torch.uint8)

        # Copy tensors to CPU and store
        state = {}

@@ -838,13 +864,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

-    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
+    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state."""
+       # Maintain backwards compatibility with older checkpoints.
        if state is None:
            return

        # Load state
        if isinstance(state, torch.Tensor):
+           # No FP8 is indicated by an empty tensor we don't need to unpickle.
+           if state.numel() == 0:
+               return
            # Default format: byte tensor with pickled data
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):

@@ -857,6 +888,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if state is None:
            return

+       # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
+       if "recipe" not in state:
+           # TE 1.x only supported delayed scaling, which was the default recipe
+           state["recipe"] = DelayedScaling()
+           # TE 1.x also saved scale_inv, which is not needed with Recipe object
+           state.pop("scale_inv_fwd", None)
+           state.pop("scale_inv_bwd", None)
+
        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]

@@ -930,6 +969,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
+       _original_recipe = self.fp8_meta.get("recipe", None)
+
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()

@@ -968,6 +1009,19 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

+       _current_recipe = self.fp8_meta["recipe"]
+       if _original_recipe is not None and not (
+           issubclass(_current_recipe.__class__, _original_recipe.__class__)
+           or issubclass(_original_recipe.__class__, _current_recipe.__class__)
+       ):
+           warnings.warn(
+               f"Recipe type changed from {_original_recipe.__class__.__name__} "
+               f"to {_current_recipe.__class__.__name__}. "
+               "This may affect model behavior."
+           )
+           # Clear cached workspaces as they were created with the old recipe/quantizer type
+           self._fp8_workspaces.clear()
+
    @contextmanager
    def prepare_forward(
        self,

@@ -992,6 +1046,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)
+           self._check_weight_tensor_recipe_correspondence()

            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (

@@ -1103,7 +1158,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            if (
                isinstance(
                    grad_output_.get_tensor(True),
-                   (QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
+                   (
+                       QuantizedTensor,
+                       Float8TensorBase,
+                       MXFP8TensorBase,
+                       Float8BlockwiseQTensorBase,
+                   ),
                )
                and ctx.use_bias
            ):

@@ -1169,18 +1229,23 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                with get_rng_state_tracker().fork():
                    init_fn(param)

-           # If primary weights are in fp8, wrap the parameter as FP8Tensor
+           # Wrap parameters in QuantizedTensor if needed
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+               # Keep high-precision values on CPU if needed
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()
+
+               # Configure quantizer
                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
-               assert (
-                   quantizer is not None
-               )  # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
+               if quantizer is None:
+                   raise RuntimeError("Weight quantizer has not been initialized")
+               quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                quantizer.internal = False
+
+               # Quantize parameter
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above

@@ -1188,6 +1253,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # re-applying the nn.Parameter() wrap is a no-op when the input is already
            # a parameter so we always re-apply it just for extra safety.
            param = torch.nn.Parameter(param)
+
+           # Keep high-precision values on CPU if needed
            if high_precision_init_val is not None:
                # - Master weights are initialized from model weights, if we use fp8 primary

@@ -1231,7 +1298,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
-       """Get FP8 workspace buffer and maybe update its values
+       """Get workspace buffer for weights and maybe update its values

        The workspace buffer may be cached for future function calls.

@@ -1257,13 +1324,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            for debug quantization, this is dtype of the tensor.
        """

-       # FP8 primary weights
+       # Handle case where weights are already quantized
+       # Note: Make sure weights have required usages, but do not
+       # destroy unnecessary usages since they may be used later.
        if isinstance(tensor, QuantizedTensor):
            if update_workspace and quantizer is not None:
-               tensor.update_usage(
-                   rowwise_usage=quantizer.rowwise_usage,
-                   columnwise_usage=quantizer.columnwise_usage,
-               )
+               update_rowwise_usage = True if quantizer.rowwise_usage else None
+               update_columnwise_usage = True if quantizer.columnwise_usage else None
+               tensor.update_usage(
+                   rowwise_usage=update_rowwise_usage,
+                   columnwise_usage=update_columnwise_usage,
+               )
            return tensor

        # Try getting workspace from cache

@@ -1387,6 +1457,43 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            )
            self.name = f"Layer_{TEDebugState.get_layer_count()}"

+    def _check_weight_tensor_recipe_correspondence(self) -> None:
+        """
+        Verify that the weight tensor types match their corresponding recipe type.
+        This is invoked in the forward().
+
+        This establishes a 1:1 correspondence between recipe types and tensor types:
+        - DelayedScaling → Float8Tensor
+        - Float8CurrentScaling → Float8Tensor
+        - MXFP8BlockScaling → MXFP8Tensor
+        - Float8BlockScaling → Float8BlockTensor
+
+        Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
+        but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).
+        """
+        if not self.fp8 and not self.fp8_calibration:
+            return
+        if not hasattr(self, "weight_names") or not self.weight_names:
+            return
+
+        recipe = self.fp8_meta["recipe"]
+        weight_tensors = [getattr(self, name) for name in self.weight_names]
+        for i, tensor in enumerate(weight_tensors):
+            if isinstance(tensor, QuantizedTensorBase):
+                quantizer = tensor._get_quantizer()
+                if quantizer is None:
+                    continue
+                compatible_recipe_class = quantizer._get_compatible_recipe()
+                if compatible_recipe_class is None:
+                    continue
+                if not isinstance(recipe, compatible_recipe_class):
+                    raise RuntimeError(
+                        f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
+                        f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
+                        " Please check the recipes assigned during fp8_model_init() and"
+                        " fp8_autocast() calls."
+                    )
+
     def _turn_off_unsupported_features_in_debug(self):
        if (
            getattr(self, "ub_bulk_wgrad", False)
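
The base-module changes above add three hooks (_get_weight_tensors(), _get_weight_quantizers(), and _update_weight_quantizers()) that the concrete modules later in this commit (GroupedLinear, LayerNormLinear, LayerNormMLP) implement. The following self-contained sketch (not from the commit) illustrates the pairing contract the base class relies on; DummyQuantizer and DummyQuantizedWeight are stand-ins for TE's Quantizer and QuantizedTensorBase.

from typing import List, Optional

class DummyQuantizer:
    def __init__(self, name: str) -> None:
        self.name = name

class DummyQuantizedWeight:
    def __init__(self) -> None:
        self.quantizer: Optional[DummyQuantizer] = None

    def update_quantizer(self, quantizer: DummyQuantizer) -> None:
        # The real code re-points the quantized weight at the freshly created quantizer
        self.quantizer = quantizer

def update_weight_quantizers(
    weights: List[object], quantizers: List[Optional[DummyQuantizer]]
) -> None:
    # Mirrors TransformerEngineBaseModule._update_weight_quantizers(): lists must align,
    # and only already-quantized weights with a non-None quantizer are updated.
    assert len(weights) == len(quantizers), "weights and quantizers must match"
    for weight, quantizer in zip(weights, quantizers):
        if quantizer is not None and isinstance(weight, DummyQuantizedWeight):
            weight.update_quantizer(quantizer)

w = DummyQuantizedWeight()
update_weight_quantizers([w, object()], [DummyQuantizer("fwd_weight_0"), None])
print(w.quantizer.name)  # fwd_weight_0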
transformer_engine/pytorch/module/grouped_linear.py  (View file @ 2b05e121)

@@ -242,8 +242,8 @@ class _GroupedLinear(torch.autograd.Function):
        biases = saved_tensors[3 * N : 4 * N]
        main_grads = ctx.main_grads

-       if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
-           for i in ctx.num_gemms:
+       if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+           for i in range(ctx.num_gemms):
                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                w.main_grad = main_grads[i]
                weights[i] = w

@@ -673,26 +673,19 @@ class GroupedLinear(TransformerEngineBaseModule):
            ), "GroupedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

-       skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+       if FP8GlobalStateManager.fp8_graph_capturing():
+           skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+       else:
+           skip_fp8_weight_update = None
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
-           weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+           weight_tensors = self._get_weight_tensors()
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
-           if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
-               warnings.warn(
-                   "You are using quantized weights without quantized compute. "
-                   "Please make sure this is intentional."
-               )
-               weight_tensors = [
-                   w.dequantize() if isinstance(w, QuantizedTensorBase) else w
-                   for w in weight_tensors
-               ]

-           input_quantizers, weight_quantizers, output_quantizers = (
-               [None] * self.num_gemms,
+           weight_quantizers = self._get_weight_quantizers()
+           input_quantizers, output_quantizers = (
                [None] * self.num_gemms,
                [None] * self.num_gemms,
            )

@@ -707,14 +700,6 @@ class GroupedLinear(TransformerEngineBaseModule):
                # TODO: use internal after #1638 is merged. # pylint: disable=fixme
                for i in range(self.num_gemms):
                    input_quantizers[i].internal = False
-               weight_quantizers = [
-                   self.quantizers["scaling_fwd"][
-                       self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
-                   ]
-                   for i in range(self.num_gemms)
-               ]
-               for i in range(self.num_gemms):
-                   weight_quantizers[i].internal = True
                if torch.is_grad_enabled():
                    grad_output_quantizers = [
                        self.quantizers["scaling_bwd"][

@@ -813,3 +798,30 @@ class GroupedLinear(TransformerEngineBaseModule):
                self.quantizers["scaling_bwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
+            warnings.warn(
+                "You are using quantized weights without quantized compute. "
+                "Please make sure this is intentional."
+            )
+            weight_tensors = [
+                w.dequantize() if isinstance(w, QuantizedTensorBase) else w
+                for w in weight_tensors
+            ]
+        return weight_tensors
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None] * self.num_gemms
+        weight_quantizers = [
+            self.quantizers["scaling_fwd"][
+                self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+            ]
+            for i in range(self.num_gemms)
+        ]
+        for i in range(self.num_gemms):
+            weight_quantizers[i].internal = True
+        return weight_quantizers
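
GroupedLinear's new _get_weight_quantizers() indexes into a flat list of quantizers: each GEMM owns a contiguous block of _num_fp8_tensors_per_gemm["fwd"] slots, and _offsets["weight"] selects the weight slot inside that block. A tiny self-contained sketch of the indexing arithmetic follows; the slot count and offsets are illustrative values, not the module's actual ones.

# Illustrative layout only: per-GEMM forward slots and their offsets are assumptions.
num_gemms = 4
num_fp8_tensors_per_gemm_fwd = 3                   # e.g. input, weight, output slot per GEMM
offsets = {"input": 0, "weight": 1, "output": 2}   # assumed slot order within a GEMM's block

# One weight-quantizer index per GEMM, exactly like
# _offsets["weight"] + i * _num_fp8_tensors_per_gemm["fwd"] in the diff above.
weight_indices = [
    offsets["weight"] + i * num_fp8_tensors_per_gemm_fwd for i in range(num_gemms)
]
print(weight_indices)  # [1, 4, 7, 10]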
transformer_engine/pytorch/module/layernorm_linear.py  (View file @ 2b05e121)

@@ -5,7 +5,7 @@
 """LayerNormLinear API"""
 import os
 import warnings
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op

@@ -63,10 +63,11 @@ from ..tensor.quantized_tensor import (
 )
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.utils import any_feature_enabled
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..cpp_extensions import (

@@ -190,19 +191,13 @@ class _LayerNormLinear(torch.autograd.Function):
                # All-gather is not supported with FP8 column-wise data
                input_quantizer.set_usage(columnwise=False)

-       # Do TP communication in high precision if quantized format
-       # does not support communication
-       force_hp_blockwise_ln_out_gather = (
-           fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
-       )
-
        # Avoid quantized norm kernel if norm output will be returned
-       # or if a gather of ln_out must be in high precision.
        with_quantized_norm = (
            fp8
+           and not debug
            and not return_layernorm_output
            and not return_layernorm_output_gathered
-           and not force_hp_blockwise_ln_out_gather
        )

        # Apply normalization

@@ -239,15 +234,16 @@ class _LayerNormLinear(torch.autograd.Function):
                ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
                ln_out_return = ln_out_total
                if fp8 or debug:
-                   if not force_hp_blockwise_ln_out_gather:
-                       ln_out = input_quantizer(ln_out)
+                   ln_out = input_quantizer(ln_out)
                    input_quantizer.set_usage(rowwise=True, columnwise=False)
+                   if isinstance(input_quantizer, Float8BlockQuantizer):
+                       input_quantizer.all_gather_usage = False
                    ln_out_total = input_quantizer(ln_out_total)
            else:
                quantizer = None
                if fp8 or debug:
                    quantizer = input_quantizer
-                   if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
+                   if not with_quantized_norm:
                        ln_out = quantizer(ln_out)
                        quantizer.set_usage(rowwise=True, columnwise=False)

        if ub_overlap_ag_fprop:
            # Initialize Userbuffers all-gather

@@ -282,7 +278,7 @@ class _LayerNormLinear(torch.autograd.Function):
            # Configure quantizer
            if weight_quantizer is not None:
-               weight_quantizer.set_usage(rowwise=True, columnwise=True)
+               weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)

            # Get quantized weight
            update_workspace = is_first_microbatch is None or is_first_microbatch

@@ -397,7 +393,6 @@ class _LayerNormLinear(torch.autograd.Function):
            ctx.ln_out_needs_gather = (
                weight.requires_grad and parallel_mode == "column" and sequence_parallel
            )
-           ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather

            # Input with column-wise usage is needed for wgrad GEMM.
            if backward_needs_input:

@@ -405,7 +400,10 @@ class _LayerNormLinear(torch.autograd.Function):
                # For sequence parallel in vanilla FP8, rowwise data is
                # to gather the input. For MXFP8, columnwise only data
                # can be allgathered.
-               if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
+               if (
+                   isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
+                   or not ctx.ln_out_needs_gather
+               ):
                    ln_out.update_usage(rowwise_usage=False)

            # Weight with column-wise usage is needed for dgrad GEMM.

@@ -502,8 +500,8 @@ class _LayerNormLinear(torch.autograd.Function):
        if return_layernorm_output:
            if return_layernorm_output_gathered:
-               shape = list(inp_shape)
-               shape[0] *= tp_size
+               shape = list(inp.shape)
+               shape[0] *= tp_size if with_input_all_gather else 1
                return out, ln_out_return.view(shape)
            return out, ln_out_return.view(inp_shape)
        return out

@@ -637,7 +635,7 @@ class _LayerNormLinear(torch.autograd.Function):
        ln_out_total_work = None
        if ctx.ln_out_needs_gather:
            quantizer = None
-           if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
+           if ctx.input_quantizer is not None:
                quantizer = ctx.input_quantizer
                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                    # If data is in FP8, we compute FP8 transposes manually

@@ -752,6 +750,31 @@ class _LayerNormLinear(torch.autograd.Function):
        wgrad = None
        if ctx.requires_wgrad:
+           # Prepare grad output tensor
+           # Note: Synchronize tensor-parallel communication and
+           # make sure required data is available
+           if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+               # UB does not support overlapping grad output
+               # all-gather with wgrad GEMM. Also, we can't
+               # convert row-scaled MXFP8 to column-scaled, so we
+               # can't reuse the grad output that was gathered
+               # for the dgrad GEMM. We work around by explicitly
+               # overlapping the NCCL operation with the dgrad GEMM.
+               ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+               # Get the communication stream from the dgrad GEMM and set it as the current torch stream
+               dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
+               with torch.cuda.stream(dgrad_comm_stream):
+                   # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
+                   # This ensures that we don't start until all communication for the dgrad GEMM is complete
+                   grad_output, mxfp8_grad_output_work = gather_along_first_dim(
+                       grad_outputs[0],
+                       ctx.tp_group,
+                       async_op=True,
+                       quantizer=ctx.grad_output_quantizer,
+                   )
+               # Synchronize with the main stream
+               mxfp8_grad_output_work.wait()
+
            # Prepare input tensor
            # Note: Synchronize tensor-parallel communication and

@@ -766,22 +789,6 @@ class _LayerNormLinear(torch.autograd.Function):
                    ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                    ln_out_total = ctx.input_quantizer(ln_out_total)

-           # Prepare grad output tensor
-           # Note: Synchronize tensor-parallel communication and
-           # make sure required data is available
-           if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
-               # UB does not support overlapping grad output
-               # all-gather with wgrad GEMM. Also, we can't
-               # convert row-scaled MXFP8 to column-scaled, so we
-               # can't reuse the grad output that was gathered
-               # for the dgrad GEMM. We work around with blocking
-               # all-gather for column-scaled MXFP8 data.
-               ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-               grad_output, _ = gather_along_first_dim(
-                   grad_outputs[0],
-                   ctx.tp_group,
-                   quantizer=ctx.grad_output_quantizer,
-               )
-
            if ctx.fp8 or ctx.debug:
                if isinstance(grad_output, QuantizedTensorBase):
                    grad_output.update_usage(columnwise_usage=True)

@@ -1389,6 +1396,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
+       elif recipe.float8_block_scaling():
+           self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        # elif other recipes (mxfp8, etc)

    def reset_layer_norm_parameters(self) -> None:

@@ -1484,20 +1493,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        ) as inp:

            # Get concatenated weight and bias tensors
-           unfused_weights = [getattr(self, name) for name in self.weight_names]
-           if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
-               if self.fp8:
-                   if len(unfused_weights) != 1:
-                       raise RuntimeError(
-                           "Splitting QuantizedTensor into multiple params is not supported"
-                       )
-               else:
-                   warnings.warn(
-                       "You are using quantized weights without quantized compute. "
-                       "Please make sure this is intentional."
-                   )
-                   unfused_weights = [w.dequantize() for w in unfused_weights]
+           unfused_weights = self._get_weight_tensors()
            weight_tensor = noop_cat(unfused_weights)
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])

@@ -1603,8 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        output_quantizer = None
        input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
        input_quantizer.internal = True
-       weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-       weight_quantizer.internal = True
+       (weight_quantizer,) = self._get_weight_quantizers()
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
        if torch.is_grad_enabled():

@@ -1679,3 +1674,39 @@ class LayerNormLinear(TransformerEngineBaseModule):
            self.quantizers["scaling_bwd"][
                tex.FP8BwdTensors.GRAD_OUTPUT1
            ].amax_reduction_group = self.tp_group
+
+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        unfused_weights = [getattr(self, name) for name in self.weight_names]
+        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
+            if self.fp8:
+                if len(unfused_weights) != 1:
+                    raise RuntimeError(
+                        "Splitting QuantizedTensor into multiple params is not supported"
+                    )
+            else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
+                unfused_weights = [w.dequantize() for w in unfused_weights]
+        return unfused_weights
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None]
+        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
+        weight_quantizer.internal = True
+        return [weight_quantizer]
+
+    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
+        assert (
+            recipe.float8_block_scaling()
+        ), "blockwise scaling recipe quantizer customization here"
+        if fwd:
+            if self.sequence_parallel and self.parallel_mode == "column":
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].all_gather_usage = True
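
Both backward passes touched by this commit move the MXFP8 grad-output all-gather onto the dgrad GEMM's communication stream, so the NCCL collective overlaps with compute instead of blocking as it did before. The sketch below (not from the commit) shows the same stream-overlap pattern using plain torch.distributed; it assumes an already initialized NCCL process group with one GPU per rank, and overlapped_allgather is a hypothetical helper name.

import torch
import torch.distributed as dist

def overlapped_allgather(local: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    world = dist.get_world_size()
    gathered = torch.empty(
        world * local.shape[0], *local.shape[1:], device=local.device, dtype=local.dtype
    )
    with torch.cuda.stream(comm_stream):
        # Enqueue the collective on the side stream; compute already queued on the
        # default stream (e.g. a dgrad GEMM) can run concurrently with it.
        work = dist.all_gather_into_tensor(gathered, local, async_op=True)
    work.wait()  # synchronize before the result is consumed (e.g. by the wgrad GEMM)
    return gathered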
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
2b05e121
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
"""LayerNormMLP API"""
"""LayerNormMLP API"""
import
os
import
os
import
warnings
import
warnings
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
,
List
from
functools
import
reduce
from
functools
import
reduce
from
operator
import
mul
as
multiply_op
from
operator
import
mul
as
multiply_op
...
@@ -244,26 +244,18 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -244,26 +244,18 @@ class _LayerNormMLP(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer
.
set_usage
(
columnwise
=
False
)
fc1_input_quantizer
.
set_usage
(
columnwise
=
False
)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_fc1_input_gather
=
(
fp8
and
sequence_parallel
and
isinstance
(
fc1_input_quantizer
,
Float8BlockQuantizer
)
)
# for fp8 DelayedScaling: layernorm output = FP8
# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
# for debug: : layernorm output = High precision to enable processing of this norm
with_quantized_norm
=
(
with_quantized_norm
=
(
fp8
fp8
and
not
debug
and
not
return_layernorm_output
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
and
not
return_layernorm_output_gathered
and
not
debug
)
)
if
isinstance
(
fc1_input_quantizer
,
Float8BlockQuantizer
):
# Kernels not available for norm fusion.
with_quantized_norm
=
False
# Apply normalization
# Apply normalization
ln_out
,
mu
,
rsigma
=
apply_normalization
(
ln_out
,
mu
,
rsigma
=
apply_normalization
(
...
@@ -293,15 +285,16 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -293,15 +285,16 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
)
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
)
ln_out_return
=
ln_out_total
ln_out_return
=
ln_out_total
if
fp8
or
debug
:
if
fp8
or
debug
:
if
not
force_hp_fc1_input_gather
:
ln_out
=
fc1_input_quantizer
(
ln_out
)
ln_out
=
fc1_input_quantizer
(
ln_out
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
isinstance
(
fc1_input_quantizer
,
Float8BlockQuantizer
):
fc1_input_quantizer
.
all_gather_usage
=
False
ln_out_total
=
fc1_input_quantizer
(
ln_out_total
)
ln_out_total
=
fc1_input_quantizer
(
ln_out_total
)
else
:
else
:
quantizer
=
None
quantizer
=
None
if
fp8
or
debug
:
if
fp8
or
debug
:
quantizer
=
fc1_input_quantizer
quantizer
=
fc1_input_quantizer
if
not
with_quantized_norm
and
not
force_hp_fc1_input_gather
:
if
not
with_quantized_norm
:
ln_out
=
fc1_input_quantizer
(
ln_out
)
ln_out
=
fc1_input_quantizer
(
ln_out
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag
:
if
ub_overlap_ag
:
...
@@ -333,8 +326,8 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -333,8 +326,8 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc.
# which handles weight caching etc.
# FP8 cast to workspace buffer
# FP8 cast to workspace buffer
update_workspace
=
is_first_microbatch
is
None
or
is_first_microbatch
update_workspace
=
is_first_microbatch
is
None
or
is_first_microbatch
fc1_weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
True
)
fc1_weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
is_grad_enabled
)
fc2_weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
True
)
fc2_weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
is_grad_enabled
)
fc1_weight_final
=
module
.
get_weight_workspace
(
fc1_weight_final
=
module
.
get_weight_workspace
(
tensor
=
fc1_weight
,
tensor
=
fc1_weight
,
quantizer
=
fc1_weight_quantizer
,
quantizer
=
fc1_weight_quantizer
,
...
@@ -567,7 +560,6 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -567,7 +560,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
tensor_objects
=
tensor_objects
ctx
.
tensor_objects
=
tensor_objects
ctx
.
fp8_recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
fp8
else
None
ctx
.
fp8_recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
fp8
else
None
ctx
.
force_hp_fc1_input_gather
=
force_hp_fc1_input_gather
ctx
.
fc1_grad_input_quantizer
=
fc1_grad_input_quantizer
ctx
.
fc1_grad_input_quantizer
=
fc1_grad_input_quantizer
ctx
.
fc1_grad_weight_quantizer
=
fc1_grad_weight_quantizer
ctx
.
fc1_grad_weight_quantizer
=
fc1_grad_weight_quantizer
ctx
.
fc1_grad_output_quantizer
=
fc1_grad_output_quantizer
ctx
.
fc1_grad_output_quantizer
=
fc1_grad_output_quantizer
...
@@ -628,7 +620,7 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -628,7 +620,7 @@ class _LayerNormMLP(torch.autograd.Function):
if
return_layernorm_output
:
if
return_layernorm_output
:
if
return_layernorm_output_gathered
:
if
return_layernorm_output_gathered
:
shape
=
list
(
inp_shape
)
shape
=
list
(
inp_shape
)
shape
[
0
]
*=
tp_size
shape
[
0
]
*=
tp_size
if
(
sequence_parallel
and
set_parallel_mode
)
else
1
return
fc2_out
,
ln_out_return
.
view
(
shape
)
return
fc2_out
,
ln_out_return
.
view
(
shape
)
return
fc2_out
,
ln_out_return
.
view
(
inp_shape
)
return
fc2_out
,
ln_out_return
.
view
(
inp_shape
)
return
fc2_out
return
fc2_out
...
@@ -743,7 +735,7 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -743,7 +735,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad
=
None
ub_obj_fc1_dgrad
=
None
if
ctx
.
fc1_weight_requires_grad
and
ctx
.
tensor_parallel
and
ctx
.
sequence_parallel
:
if
ctx
.
fc1_weight_requires_grad
and
ctx
.
tensor_parallel
and
ctx
.
sequence_parallel
:
quantizer
=
None
quantizer
=
None
if
ctx
.
fp8
or
ctx
.
debug
and
not
ctx
.
force_hp_fc1_input_gather
:
if
ctx
.
fp8
or
ctx
.
debug
:
quantizer
=
ctx
.
fc1_input_quantizer
quantizer
=
ctx
.
fc1_input_quantizer
if
isinstance
(
quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
)):
if
isinstance
(
quantizer
,
(
Float8Quantizer
,
Float8CurrentScalingQuantizer
)):
# If data is in FP8, we compute FP8 transposes manually
# If data is in FP8, we compute FP8 transposes manually
...
@@ -841,6 +833,30 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -841,6 +833,30 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_wgrad
=
None
fc2_wgrad
=
None
if
ctx
.
fc2_weight_requires_grad
:
if
ctx
.
fc2_weight_requires_grad
:
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
ub_overlap_ag
and
isinstance
(
ctx
.
fc2_grad_output_quantizer
,
MXFP8Quantizer
):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
ctx
.
fc2_grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream
=
ub_obj_fc2_dgrad
.
get_communication_stream
()
with
torch
.
cuda
.
stream
(
dgrad_comm_stream
):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output
,
mxfp8_fc2_grad_output_work
=
gather_along_first_dim
(
grad_outputs
[
0
],
ctx
.
tp_group
,
async_op
=
True
,
quantizer
=
ctx
.
fc2_grad_output_quantizer
,
)
# Synchronize with the main stream
mxfp8_fc2_grad_output_work
.
wait
()
# Prepare input tensor
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# Note: Synchronize tensor-parallel communication and
...
@@ -852,22 +868,6 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -852,22 +868,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
fc2_input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
ctx
.
fc2_input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
act_out
=
ctx
.
fc2_input_quantizer
(
act_out
)
act_out
=
ctx
.
fc2_input_quantizer
(
act_out
)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
ub_overlap_ag
and
isinstance
(
ctx
.
fc2_grad_output_quantizer
,
MXFP8Quantizer
):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx
.
fc2_grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
,
_
=
gather_along_first_dim
(
grad_outputs
[
0
],
ctx
.
tp_group
,
quantizer
=
ctx
.
fc2_grad_output_quantizer
,
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
grad_output
.
update_usage
(
columnwise_usage
=
True
)
...
@@ -1661,8 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -1661,8 +1661,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
super
().
set_meta_tensor
(
fwd
,
recipe
)
super
().
set_meta_tensor
(
fwd
,
recipe
)
# customize quantizers based on each recipe & layer configs
# customize quantizers based on each recipe & layer configs
if
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_current_scaling
():
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
recipe
.
float8_current_scaling
():
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
elif
recipe
.
float8_block_scaling
():
self
.
_customize_quantizers_float8_blockwise_scaling
(
fwd
,
recipe
)
# elif for other recipes (mxfp8, etc.)
# elif for other recipes (mxfp8, etc.)
def
reset_layer_norm_parameters
(
self
)
->
None
:
def
reset_layer_norm_parameters
(
self
)
->
None
:
...
@@ -1772,15 +1775,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -1772,15 +1775,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
)
=
quantizers
)
=
quantizers
# Get weight tensors
# Get weight tensors
fc1_weight
=
self
.
fc1
_weight
fc1_weight
,
fc2_weight
=
self
.
_get
_weight
_tensors
()
fc1_bias
=
self
.
fc1_bias
if
self
.
use_bias
else
None
fc1_bias
=
self
.
fc1_bias
if
self
.
use_bias
else
None
fc2_weight
=
self
.
fc2_weight
fc2_bias
=
self
.
fc2_bias
if
self
.
use_bias
else
None
fc2_bias
=
self
.
fc2_bias
if
self
.
use_bias
else
None
if
not
self
.
fp8
:
if
not
self
.
fp8
:
if
isinstance
(
fc1_weight
,
Float8Tensor
):
if
isinstance
(
fc1_weight
,
Float8Tensor
):
fc1_weight
=
fc1_weight
.
from_float8
()
fc1_weight
=
fc1_weight
.
dequantize
()
if
isinstance
(
fc2_weight
,
Float8Tensor
):
if
isinstance
(
fc2_weight
,
Float8Tensor
):
fc2_weight
=
fc2_weight
.
from_float8
()
fc2_weight
=
fc2_weight
.
dequantize
()
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if
(
not
IS_HIP_EXTENSION
if
(
not
IS_HIP_EXTENSION
...
@@ -1866,31 +1868,26 @@ class LayerNormMLP(TransformerEngineBaseModule):
     def _get_quantizers(self, fp8_output):
         (
             fc1_input_quantizer,
-            fc1_weight_quantizer,
             fc1_output_quantizer,
             fc1_grad_input_quantizer,
             fc1_grad_weight_quantizer,
             fc1_grad_output_quantizer,
             fc2_input_quantizer,
-            fc2_weight_quantizer,
             fc2_output_quantizer,
             fc2_grad_input_quantizer,
             fc2_grad_weight_quantizer,
             fc2_grad_output_quantizer,
-        ) = [None] * 12
+        ) = [None] * 10
+        fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
         if self.fp8:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             fc1_input_quantizer.internal = True
-            fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-            fc1_weight_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
             fc2_input_quantizer.set_usage(
                 rowwise=True,
                 columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
             )
             fc1_input_quantizer.internal = True
-            fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
-            fc2_weight_quantizer.internal = True
             if fp8_output:
                 fc2_output_quantizer = self.quantizers["scaling_fwd"][
                     tex.FP8FwdTensors.GEMM2_OUTPUT
...
@@ -2007,6 +2004,36 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 tex.FP8BwdTensors.GRAD_OUTPUT2
             ].amax_reduction_group = self.tp_group

+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        return [self.fc1_weight, self.fc2_weight]
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None, None]
+        fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
+        fc1_weight_quantizer.internal = True
+        fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
+        fc2_weight_quantizer.internal = True
+        return [fc1_weight_quantizer, fc2_weight_quantizer]
+
+    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
+        assert (
+            recipe.float8_block_scaling()
+        ), "blockwise scaling recipe quantizer customization here"
+        if fwd:
+            if self.sequence_parallel and self.set_parallel_mode:
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].all_gather_usage = True
+        else:
+            if self.sequence_parallel and self.set_parallel_mode:
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
+                ].all_gather_usage = True
+
     def backward_dw(self):
         """
         Execute the delayed weight gradient computation.
...
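
For reference, the blockwise-scaling hooks above are only reached when the active FP8 recipe reports float8_block_scaling(). A minimal usage sketch, assuming the Float8BlockScaling recipe class from transformer_engine.common.recipe (layer sizes and single-GPU setup are illustrative, not part of this commit):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Blockwise-scaling recipe (assumed class name); when it is active,
# set_meta_tensor() customizes the quantizers, e.g. all_gather_usage
# for sequence-parallel gathers.
fp8_recipe = recipe.Float8BlockScaling()

layer = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096).cuda()
x = torch.randn(16, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()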
transformer_engine/pytorch/module/linear.py
View file @ 2b05e121
...
@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 """Linear API"""
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union, List
 from functools import reduce
 from operator import mul as multiply_op

 import warnings
...
@@ -67,7 +67,7 @@ from ..tensor.quantized_tensor import (
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
+from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.utils import any_feature_enabled
...
@@ -137,12 +137,6 @@ class _Linear(torch.autograd.Function):
             parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
         )
-        # Do TP communication in high precision if quantized format
-        # does not support communication
-        force_hp_input_gather = (
-            fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
-        )

         # Configure Userbuffers communication (comm+GEMM overlap)
         ub_obj = None
         ub_type = None
...
@@ -169,7 +163,7 @@ class _Linear(torch.autograd.Function):
         if fp8 or debug:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
+            if not isinstance(inputmat, QuantizedTensorBase):
                 input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                 if isinstance(
                     input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
...
@@ -348,10 +342,11 @@ class _Linear(torch.autograd.Function):
                 # For sequence parallel in vanilla FP8, rowwise data is
                 # to gather the input. For MXFP8, columnwise only data
                 # can be allgathered.
-                if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
+                if (
+                    isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
+                    or not ctx.backward_input_needs_gather
+                ):
                     inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
-            if force_hp_input_gather:
-                assert not isinstance(inputmat, QuantizedTensorBase)
             saved_inputmat = inputmat

         # Weight with column-wise usage is needed for dgrad GEMM.
...
@@ -397,7 +392,6 @@ class _Linear(torch.autograd.Function):
         ctx.activation_dtype = activation_dtype
         ctx.fp8 = fp8
         ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
-        ctx.force_hp_input_gather = force_hp_input_gather
         ctx.input_quantizer = input_quantizer
         ctx.grad_input_quantizer = grad_input_quantizer
         ctx.grad_weight_quantizer = grad_weight_quantizer
...
@@ -558,7 +552,7 @@ class _Linear(torch.autograd.Function):
         inputmat_total_work = None
         if ctx.backward_input_needs_gather:
             quantizer = None
-            if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
+            if ctx.fp8 or ctx.debug:
                 quantizer = ctx.input_quantizer
                 if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                     # If data is in FP8, we compute FP8 transposes manually
@@ -696,14 +690,23 @@ class _Linear(torch.autograd.Function):
...
@@ -696,14 +690,23 @@ class _Linear(torch.autograd.Function):
# all-gather with wgrad GEMM. Also, we can't
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around
with blocking
# for the dgrad GEMM. We work around
by explicitly
#
all-gather for column-scaled MXFP8 data
.
#
overlapping the NCCL operation with the dgrad GEMM
.
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
grad_output
,
_
=
gather_along_first_dim
(
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
grad_output_arg
,
dgrad_comm_stream
=
ub_obj_dgrad
.
get_communication_stream
()
ctx
.
tp_group
,
with
torch
.
cuda
.
stream
(
dgrad_comm_stream
):
quantizer
=
ctx
.
grad_output_quantizer
,
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
)
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output
,
grad_output_work
=
gather_along_first_dim
(
grad_output_arg
,
ctx
.
tp_group
,
async_op
=
True
,
quantizer
=
ctx
.
grad_output_quantizer
,
)
# Synchronize with the main stream
grad_output_work
.
wait
()
if
ctx
.
fp8
or
ctx
.
debug
:
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
if
isinstance
(
grad_output
,
QuantizedTensorBase
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
grad_output
.
update_usage
(
columnwise_usage
=
True
)
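
The rewritten branch above hides the column-scaled grad-output all-gather behind the dgrad GEMM by enqueueing an asynchronous collective on the GEMM's communication stream and only waiting on it afterwards. The general pattern, as a standalone sketch in plain PyTorch (process-group setup assumed; side_stream stands in for the stream returned by ub_obj_dgrad.get_communication_stream()):

import torch
import torch.distributed as dist

def gather_on_side_stream(local, group, side_stream):
    # All-gather `local` along dim 0 on `side_stream`; the collective is ordered
    # after whatever the side stream is already running, so it overlaps with
    # compute queued on the default stream.
    world_size = dist.get_world_size(group)
    out = torch.empty(
        (world_size * local.size(0), *local.shape[1:]),
        dtype=local.dtype, device=local.device,
    )
    with torch.cuda.stream(side_stream):
        work = dist.all_gather_into_tensor(out, local.contiguous(), group=group, async_op=True)
    return out, work

# Usage: out, work = gather_on_side_stream(dy_local, tp_group, side_stream)
#        ...launch independent work...
#        work.wait()  # synchronize before using `out`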
...
@@ -1218,6 +1221,8 @@ class Linear(TransformerEngineBaseModule):
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if recipe.float8_current_scaling():
                 self._customize_quantizers_float8_current_scaling(fwd, recipe)
+            elif recipe.float8_block_scaling():
+                self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
             # elif for other recipes (mxfp8, etc.)

     def reset_parameters(self, defer_init=False):
...
@@ -1294,20 +1299,7 @@ class Linear(TransformerEngineBaseModule):
        ) as inp:

            # Get concatenated weight and bias tensors
-           unfused_weights = [getattr(self, name) for name in self.weight_names]
-           if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
-               if self.fp8:
-                   if len(unfused_weights) != 1:
-                       raise RuntimeError(
-                           "Splitting QuantizedTensor into multiple params is not supported"
-                       )
-               else:
-                   warnings.warn(
-                       "You are using quantized weights without quantized compute. "
-                       "Please make sure this is intentional."
-                   )
-                   unfused_weights = [w.dequantize() for w in unfused_weights]
+           unfused_weights = self._get_weight_tensors()
            weight_tensor = noop_cat(unfused_weights)
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
...
@@ -1337,12 +1329,6 @@ class Linear(TransformerEngineBaseModule):
             grad_output_quantizer,
         ) = quantizers

-        # Make sure weight tensor has correct quantizer
-        # Note: Quantizer might have changed if quantization
-        # recipe changed
-        if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor):
-            weight_tensor._quantizer = weight_quantizer
-
         if torch.is_grad_enabled():
             linear_fn = _Linear.apply
             args = []
...
@@ -1403,8 +1389,7 @@ class Linear(TransformerEngineBaseModule):
             output_quantizer = None
             input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             input_quantizer.internal = True
-            weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
-            weight_quantizer.internal = True
+            (weight_quantizer,) = self._get_weight_quantizers()
             if fp8_output:
                 output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
         if torch.is_grad_enabled():
...
@@ -1478,3 +1463,47 @@ class Linear(TransformerEngineBaseModule):
             self.quantizers["scaling_bwd"][
                 tex.FP8BwdTensors.GRAD_OUTPUT1
             ].amax_reduction_group = self.tp_group

+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        unfused_weights = [getattr(self, name) for name in self.weight_names]
+        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
+            if self.fp8:
+                if len(unfused_weights) != 1:
+                    raise RuntimeError(
+                        "Splitting QuantizedTensor into multiple params is not supported"
+                    )
+            else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
+                unfused_weights = [w.dequantize() for w in unfused_weights]
+        return unfused_weights
+
+    def _get_weight_quantizers(self) -> List[Quantizer]:
+        """Get the weight quantizers of the module."""
+        if not self.fp8:
+            return [None]
+        weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
+        weight_quantizer.internal = True
+        return [weight_quantizer]
+
+    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on blockwise scaling recipe + linear."""
+        assert (
+            recipe.float8_block_scaling()
+        ), "blockwise scaling recipe quantizer customization here"
+        if fwd:
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # set compact for inp tensor X
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].all_gather_usage = True
+        else:
+            if self.sequence_parallel and self.parallel_mode == "row":
+                # set compact for grad_output tensor dY
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].all_gather_usage = True
transformer_engine/pytorch/ops/basic/__init__.py
View file @ 2b05e121
...
@@ -11,6 +11,7 @@ from .all_reduce import AllReduce
 from .basic_linear import BasicLinear
 from .bias import Bias
 from .identity import Identity
+from .l2normalization import L2Normalization
 from .layer_norm import LayerNorm
 from .make_extra_output import MakeExtraOutput
 from .quantize import Quantize
...
transformer_engine/pytorch/ops/basic/activation.py
View file @ 2b05e121
...
@@ -96,12 +96,15 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         if not x.is_contiguous():
             x = x.contiguous()

-        # Check if FP8 is enabled
-        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
-        if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0:
+        # Check if quantized compute is enabled
+        quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        quantizer = None
+        if (
+            quantized_compute_enabled
+            and next_op is not None
+            and next_op.num_quantizers("forward") > 0
+        ):
             quantizer = next_op.get_quantizer("forward", 0)
-        else:
-            quantizer = None

         # Launch kernel
         y = self._activation_forward_impl(
...
@@ -115,13 +118,13 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         # Quantize input to FP8 before caching if needed
         if self.cache_quantized_input:
-            quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
-            quantizer.set_usage(rowwise=True, columnwise=False)
-            x = quantizer(x)
+            input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
+            input_quantizer.set_usage(rowwise=True, columnwise=False)
+            x = input_quantizer(x)

         # Save state for backward pass
         ctx.save_for_backward(x.detach())
-        ctx.fp8_enabled = fp8_enabled
+        ctx.quantized_compute_enabled = quantized_compute_enabled
         ctx.dtype = dtype
         ctx.prev_op = prev_op
...
@@ -153,11 +156,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         if not dy.is_contiguous():
             dy = dy.contiguous()

+        # Check if quantized compute is enabled
+        quantizer = None
+        if (
+            ctx.quantized_compute_enabled
+            and ctx.prev_op is not None
+            and ctx.prev_op.num_quantizers("backward") > 0
+        ):
+            quantizer = ctx.prev_op.get_quantizer("backward", 0)
+
         # Launch kernel
         dx = self._activation_backward_impl(
             reshape(dy, (-1, dy.size(-1))),
             reshape(x, (-1, x.size(-1))),
-            None,
+            quantizer,
         )

         # Check grad input tensor
...
transformer_engine/pytorch/ops/basic/basic_linear.py
View file @ 2b05e121
...
@@ -22,7 +22,7 @@ from ...distributed import (
 from ...fp8 import FP8GlobalStateManager
 from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
 from ...tensor import Quantizer, QuantizedTensor
-from ...tensor.float8_tensor import Float8Quantizer
+from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
...
@@ -324,12 +324,38 @@ class BasicLinear(BasicOperation):
         weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
         grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)

+        # Recipe-specific configuration
+        recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if recipe.float8_current_scaling():
+            if any(
+                not isinstance(q, Float8CurrentScalingQuantizer)
+                for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
+            ):
+                raise RuntimeError(
+                    "FP8 current-scaling recipe is enabled, "
+                    f"but input quantizer is {input_quantizer.__class__.__name__}, "
+                    f"weight quantizer is {weight_quantizer.__class__.__name__}, "
+                    f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
+                )
+            input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+            weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+            grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+            if self.sequence_parallel and self.tensor_parallel_mode == "column":
+                input_quantizer.with_amax_reduction = True
+                input_quantizer.amax_reduction_group = self.tensor_parallel_group
+            if self.sequence_parallel and self.tensor_parallel_mode == "row":
+                grad_output_quantizer.with_amax_reduction = True
+                grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+
         # Make sure weight tensor has correct quantizer
         # Note: Quantizer might have changed if quantization
         # recipe changed
-        if isinstance(weight_quantizer, Float8Quantizer) and isinstance(
-            weight, Float8TensorBase
-        ):
+        if isinstance(
+            weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+        ) and isinstance(weight, Float8TensorBase):
             weight._quantizer = weight_quantizer

     @staticmethod
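
When amax reduction is enabled for sequence-parallel layouts as above, every tensor-parallel rank ends up quantizing its shard with the same FP8 scale. A minimal sketch of that idea in plain PyTorch (group setup assumed; this illustrates only the scale computation, not the quantizer's actual implementation):

import torch
import torch.distributed as dist

def shared_fp8_scale(x, group, fp8_max=448.0):
    # Reduce the amax over the group so all ranks derive an identical scale
    # for their shard of a sequence-parallel tensor (448.0 is the E4M3 max).
    amax = x.abs().amax().float()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return fp8_max / torch.clamp(amax, min=1e-12)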
...
@@ -349,7 +375,9 @@ class BasicLinear(BasicOperation):
         input_quantizer: Optional[Quantizer] = None,
         weight_quantizer: Optional[Quantizer] = None,
         output_quantizer: Optional[Quantizer] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        input_requires_grad: bool = True,
+        weight_requires_grad: bool = True,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Functional API for forward pass

         Parameters
...
@@ -385,17 +413,25 @@ class BasicLinear(BasicOperation):
             Builder class for quantized weight tensor.
         output_quantizer: Quantizer, optional
             Builder class for quantized output tensor.
+        input_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the input tensor is
+            required in the backward pass.
+        weight_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the weight tensor is
+            required in the backward pass.

         Returns
         -------
         torch.Tensor
             Output tensor
-        torch.Tensor
-            Input tensor used in GEMM, possibly cast and reshaped from
-            provided input tensor
-        torch.Tensor
-            Weight tensor used in GEMM, possibly cast and reshaped from
-            provided weight tensor
+        torch.Tensor, optional
+            Input tensor, ready for use in backward pass. `None` is
+            returned if loss gradient w.r.t. the weight tensor is not
+            required.
+        torch.Tensor, optional
+            Weight tensor, ready for use in backward pass. `None` is
+            returned if loss gradient w.r.t. the input tensor is not
+            required.

         """
...
@@ -416,7 +452,7 @@ class BasicLinear(BasicOperation):
         if with_quantized_compute:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(rowwise=True)
+            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             if with_x_all_gather:
                 input_quantizer.set_usage(columnwise=False)
                 x, x_async = gather_along_first_dim(
...
@@ -449,7 +485,7 @@ class BasicLinear(BasicOperation):
         if with_quantized_compute and not w_is_quantized:
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
-            weight_quantizer.set_usage(rowwise=True)
+            weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
...
@@ -526,17 +562,25 @@ class BasicLinear(BasicOperation):
             else:
                 torch.distributed.all_reduce(y, group=tensor_parallel_group)

-        # Detach input tensor if needed
-        # Note: PyTorch autograd produces esoteric errors if we save
-        # input tensor as context for backward pass.
-        if x_local is input:
-            x_local = x_local.detach()
-
-        # Configure input tensor for backward pass
-        if with_quantized_compute and isinstance(x_local, QuantizedTensor):
-            if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
-                # FP8 does not support all-gather of transpose data
-                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        # Prepare weight tensor for backward pass
+        if input_requires_grad:
+            if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
+                w.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            w = None
+
+        # Prepare input tensor for backward pass
+        if weight_requires_grad:
+            if x_local is input:
+                # PyTorch autograd produces esoteric errors if we
+                # cache input tensor directly.
+                x_local = x_local.detach()
+            if with_quantized_compute and isinstance(x_local, QuantizedTensor):
+                if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
+                    # FP8 does not support all-gather of transpose data
+                    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            x_local = None

         return y, x_local, w
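
The forward pass above now saves only what the backward pass will read: the input is dropped when the weight gradient is not needed, and the (possibly quantized) weight is dropped when the input gradient is not needed. The same idea in a self-contained toy autograd function (illustrative only, not the TE implementation):

import torch

class LinearSaveWhatYouNeed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        y = x @ w.t()
        # dgrad needs w, wgrad needs x; skip whichever gradient is not required.
        saved_x = x if w.requires_grad else None
        saved_w = w if x.requires_grad else None
        ctx.save_for_backward(saved_x, saved_w)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        dx = dy @ w if w is not None else None
        dw = dy.t() @ x if x is not None else None
        return dx, dw

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16)  # frozen weight: its gradient is never needed
y = LinearSaveWhatYouNeed.apply(x, w)
y.sum().backward()  # only w was kept for backward, since the input needs a gradient but w does not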
...
@@ -892,7 +936,7 @@ class BasicLinear(BasicOperation):
             dtype = torch.get_autocast_dtype("cuda")

         # Linear forward
-        output, x_local, _ = BasicLinear._functional_forward(
+        output, x_local, w = BasicLinear._functional_forward(
             input=input_,
             weight=self.weight,
             dtype=dtype,
...
@@ -903,10 +947,12 @@ class BasicLinear(BasicOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=output_quantizer,
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
         )

         # Save state for backward pass
-        ctx.save_for_backward(x_local)
+        ctx.save_for_backward(x_local, w)
         ctx.with_quantized_compute = with_quantized_compute
         ctx.input_quantizer = input_quantizer
         ctx.weight_quantizer = weight_quantizer
...
@@ -926,7 +972,7 @@ class BasicLinear(BasicOperation):
     ) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:

         # Saved tensors from forward pass
-        (x_local,) = ctx.saved_tensors
+        (x_local, w) = ctx.saved_tensors

         # wgrad fusion
         accumulate_into_main_grad = self._accumulate_into_main_grad
...
@@ -946,7 +992,7 @@ class BasicLinear(BasicOperation):
         grad_input, grad_weight = BasicLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
-            weight=self.weight,
+            weight=w,
             input_requires_grad=ctx.input_requires_grad,
             weight_requires_grad=ctx.weight_requires_grad,
             dtype=ctx.dtype,
...
...
transformer_engine/pytorch/ops/basic/l2normalization.py
0 → 100644
View file @ 2b05e121
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusable operation for L2 Normalization."""

from __future__ import annotations
from typing import Optional

import torch

from ...tensor import QuantizedTensor
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from ...jit import (
    l2normalization_fused,
    l2normalization_fwd_fused,
    l2normalization_backward_fused,
    set_jit_fusion_options,
    warmup_jit_l2normalization_all_dtypes,
)


class L2Normalization(BasicOperation):
    r"""L2 Normalization

    Applies L2 normalization over the last dimension of input tensors.
    This is a parameter-free normalization that scales each vector to unit L2 norm.

    .. math::
        y = \frac{x}{\sqrt{\sum_{i} x_i^2 + \varepsilon}}

    This operation is used e.g. for query-key normalization in attention mechanisms.

    Parameters
    ----------
    eps : float, default = 1e-6
        A value added to the denominator for numerical stability
    seq_length: int, default = None
        sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
        functions are warmed up before training to ensure same kernels are used for forward
        propagation and activation recompute phase.
    micro_batch_size: int, default = None
        batch size per training step. Needed for JIT Warmup, a technique where jit
        fused functions are warmed up before training to ensure same kernels are
        used for forward propagation and activation recompute phase.

    """

    def __init__(
        self,
        *,
        eps: float = 1e-6,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.eps: float = eps

        # JIT warmup for L2Normalization fused operations
        if seq_length and micro_batch_size:
            if torch.cuda.is_available():
                set_jit_fusion_options()
                # For L2Normalization, we don't know the hidden size until forward pass,
                # but we can warm up with common sizes. For QK normalization, this will be
                # the attention head dimension (hidden_size_per_attention_head), not the full
                # model hidden dimension. Common head dimensions are 32, 64, 80, 96, 128, 256.
                common_hidden_sizes = [32, 64, 80, 96, 128, 256]
                for hidden_size in common_hidden_sizes:
                    warmup_jit_l2normalization_all_dtypes(hidden_size, seq_length, micro_batch_size)

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
    ) -> torch.Tensor:

        # Use input directly - torch.compile can handle multi-dimensional tensors
        x = input_
        if isinstance(x, QuantizedTensor):
            x = x.dequantize()

        # Check if backward pass is needed
        requires_grad = ctx.requires_grad

        # Compute L2 normalization using fused implementation
        # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
        if requires_grad:
            # Training: use version that returns both output and intermediate values
            y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
        else:
            # Inference: use lightweight version that only returns output
            y = l2normalization_fused(x, self.eps)
            rsqrt_norm = None  # Not needed for inference

        # Save state for backward pass
        if requires_grad:
            ctx.save_for_backward(x, rsqrt_norm)
            ctx.has_prev_op = prev_op is not None

        return y

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        x, rsqrt_norm = ctx.saved_tensors

        dy = grad_output
        if isinstance(dy, QuantizedTensor):
            dy = dy.dequantize()

        # Compute L2 norm backward pass using fused implementation
        dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)

        # Clear saved tensors if possible
        if ctx.has_prev_op:
            clear_tensor_data(x)
        clear_tensor_data(rsqrt_norm)

        # No parameters, so empty tuple for param grads
        return dx, ()
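
For reference, the fused kernels this file calls into implement the following forward and backward math; a minimal sketch in plain PyTorch (the function names are illustrative):

import torch

def l2norm_ref(x, eps=1e-6):
    # y = x * rsqrt(sum(x^2) + eps), reduced over the last dimension
    rsqrt_norm = torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    return x * rsqrt_norm

def l2norm_bwd_ref(dy, x, eps=1e-6):
    # dx = rsqrt_norm * (dy - x * <dy, x> * rsqrt_norm^2)
    rsqrt_norm = torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    dot = (dy * x).sum(dim=-1, keepdim=True)
    return rsqrt_norm * (dy - x * dot * rsqrt_norm.pow(2))

x = torch.randn(4, 64, dtype=torch.double, requires_grad=True)
y = l2norm_ref(x)
(dx_autograd,) = torch.autograd.grad(y.sum(), x)
assert torch.allclose(dx_autograd, l2norm_bwd_ref(torch.ones_like(y), x))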
transformer_engine/pytorch/ops/fused/backward_linear_add.py
View file @ 2b05e121
...
@@ -51,7 +51,7 @@ class BackwardLinearAdd(FusedOperation):
         linear_op_ctx = basic_op_ctxs[0]

         # Saved tensors from forward pass
-        (x_local,) = linear_op_ctx.saved_tensors
+        (x_local, w) = linear_op_ctx.saved_tensors

         # wgrad fusion
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
...
@@ -72,7 +72,7 @@ class BackwardLinearAdd(FusedOperation):
         grad_input, grad_weight = BasicLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
-            weight=linear_op.weight,
+            weight=w,
             input_requires_grad=linear_op_ctx.input_requires_grad,
             weight_requires_grad=linear_op_ctx.weight_requires_grad,
             dtype=grad_input.dtype,
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
View file @ 2b05e121
...
@@ -82,6 +82,10 @@ class ForwardLinearBiasActivation(FusedOperation):
         else:
             raise NotImplementedError("Activations are not yet supported")

+        # Check which grads are required
+        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+
         # FP8 metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
...
@@ -106,7 +110,7 @@ class ForwardLinearBiasActivation(FusedOperation):
             dtype = torch.get_autocast_dtype("cuda")

         # Linear forward
-        output, x_local, _ = BasicLinear._functional_forward(
+        output, x_local, w = BasicLinear._functional_forward(
             input=input_,
             weight=linear_op.weight,
             bias=bias,
...
@@ -118,18 +122,20 @@ class ForwardLinearBiasActivation(FusedOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=output_quantizer,
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
         )

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local)
+        linear_op_ctx.save_for_backward(x_local, w)
         linear_op_ctx.with_quantized_compute = with_quantized_compute
         linear_op_ctx.input_quantizer = input_quantizer
         linear_op_ctx.weight_quantizer = weight_quantizer
         linear_op_ctx.grad_output_quantizer = grad_output_quantizer
         linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_requires_grad = input_.requires_grad
-        linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
+        linear_op_ctx.input_requires_grad = input_requires_grad
+        linear_op_ctx.weight_requires_grad = weight_requires_grad
         linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None

         return output, [() for _ in range(len(self.basic_ops))]
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
View file @ 2b05e121
...
@@ -76,6 +76,10 @@ class ForwardLinearBiasAdd(FusedOperation):
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")

+        # Check which grads are required
+        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+
         # FP8 metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
...
@@ -98,7 +102,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         # Linear forward
         output = basic_op_extra_inputs[self._op_idxs["add"]][0]
-        output, x_local, _ = BasicLinear._functional_forward(
+        output, x_local, w = BasicLinear._functional_forward(
             input=input_,
             weight=linear_op.weight,
             bias=bias,
...
@@ -111,18 +115,20 @@ class ForwardLinearBiasAdd(FusedOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=output_quantizer,
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
         )

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local)
+        linear_op_ctx.save_for_backward(x_local, w)
         linear_op_ctx.with_quantized_compute = with_quantized_compute
         linear_op_ctx.input_quantizer = input_quantizer
         linear_op_ctx.weight_quantizer = weight_quantizer
         linear_op_ctx.grad_output_quantizer = grad_output_quantizer
         linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
-        linear_op_ctx.input_requires_grad = input_.requires_grad
-        linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
+        linear_op_ctx.input_requires_grad = input_requires_grad
+        linear_op_ctx.weight_requires_grad = weight_requires_grad
         linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None

         return output, [() for _ in range(len(self.basic_ops))]
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
View file @ 2b05e121
...
@@ -407,16 +407,26 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Initialize grad output
         if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
             # UB does not support overlapping grad output
-            # all-gather with wgrad GEMM. Also, MXFP8 does not
-            # allow reusing the grad output that was gathered for
-            # the dgrad GEMM. We work around with blocking
-            # all-gather for column-scaled MXFP8 data.
+            # all-gather with wgrad GEMM. Also, we can't
+            # convert row-scaled MXFP8 to column-scaled, so we
+            # can't reuse the grad output that was gathered
+            # for the dgrad GEMM. We work around by explicitly
+            # overlapping the NCCL operation with the dgrad GEMM.
             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-            dy, _ = gather_along_first_dim(
-                grad_output,
-                tensor_parallel_group,
-                quantizer=grad_output_quantizer,
-            )
+            # Get the communication stream from the dgrad GEMM and set it as the current torch stream
+            dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
+            with torch.cuda.stream(dgrad_comm_stream):
+                # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
+                # This ensures that we don't start until all communication for the dgrad GEMM is complete
+                dy, dy_work = gather_along_first_dim(
+                    dy_local,
+                    tensor_parallel_group,
+                    async_op=True,
+                    quantizer=grad_output_quantizer,
+                )
+            # Synchronize with the main stream
+            dy_work.wait()
         if tensor_parallel_mode == "column":
             dy = dy_local
         if dy is None:
...
@@ -500,7 +510,7 @@ class UserbuffersBackwardLinear(FusedOperation):
             bias_op = self.basic_ops[idx]

         # Saved tensors from forward pass
-        (x_local,) = linear_op_ctx.saved_tensors
+        (x_local, w) = linear_op_ctx.saved_tensors

         # wgrad fusion
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
...
@@ -520,7 +530,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         retval = UserbuffersBackwardLinear._functional_backward(
             grad_output=grad_output,
             input=x_local,
-            weight=linear_op.weight,
+            weight=w,
             weight_requires_grad=linear_op_ctx.weight_requires_grad,
             bias_requires_grad=(bias_op is not None),
             dtype=linear_op_ctx.dtype,
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @ 2b05e121
...
@@ -21,7 +21,7 @@ from ...module.base import (
     _2X_ACC_FPROP,
 )
 from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
-from ...tensor.float8_tensor import Float8Quantizer
+from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
 from ...utils import canonicalize_device, canonicalize_dtype
 from ..basic import BasicLinear, Bias, ReduceScatter
...
@@ -98,6 +98,8 @@ class UserbuffersForwardLinear(FusedOperation):
         input_quantizer: Optional[Quantizer] = None,
         weight_quantizer: Optional[Quantizer] = None,
         output_quantizer: Optional[Quantizer] = None,
+        input_requires_grad: bool = True,
+        weight_requires_grad: bool = True,
         ub_comm_name: str,
     ) -> tuple[torch.Tensor, dict]:
         """Functional API for forward pass
...
@@ -131,6 +133,12 @@ class UserbuffersForwardLinear(FusedOperation):
             Builder class for quantized weight tensor.
         output_quantizer: Quantizer, optional
             Builder class for quantized output tensor.
+        input_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the input tensor is
+            required in the backward pass.
+        weight_requires_grad: bool, default = `True`
+            Whether the loss gradient w.r.t. the weight tensor is
+            required in the backward pass.
         ub_comm_name: str
             Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
             used to access the corresponding Userbuffers communicators
...
@@ -141,8 +149,9 @@ class UserbuffersForwardLinear(FusedOperation):
         torch.Tensor
             Output tensor
         dict
-            Extra output tensors. "input" is the input tensor,
-            possibly cast and reshaped from the provided input tensor.
+            Extra output tensors. "input" is the input tensor and
+            "weight" is the weight tensor, both ready for use in the
+            backward pass.

         """
...
@@ -198,8 +207,10 @@ class UserbuffersForwardLinear(FusedOperation):
         if with_ub_all_gather:
             if input_quantizer is not None:
                 if not isinstance(x_local, QuantizedTensorBase):
-                    input_quantizer.set_usage(rowwise=True, columnwise=True)
-                    if isinstance(input_quantizer, Float8Quantizer):
+                    input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+                    if isinstance(
+                        input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                    ):
                         input_quantizer.set_usage(columnwise=False)
                     x_local = input_quantizer(x_local)
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -212,7 +223,7 @@ class UserbuffersForwardLinear(FusedOperation):
         else:
             if with_quantized_compute:
                 if not isinstance(x_local, QuantizedTensorBase):
-                    input_quantizer.set_usage(rowwise=True, columnwise=True)
+                    input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
                     x_local = input_quantizer(x_local)
             else:
                 if isinstance(x_local, QuantizedTensorBase):
...
@@ -225,7 +236,7 @@ class UserbuffersForwardLinear(FusedOperation):
         w = weight
         w_is_quantized = isinstance(w, QuantizedTensorBase)
         if with_quantized_compute and not w_is_quantized:
-            weight_quantizer.set_usage(rowwise=True)
+            weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
...
@@ -258,17 +269,25 @@ class UserbuffersForwardLinear(FusedOperation):
         else:
             y_local = gemm_output

-        # Detach input tensor if needed
-        # Note: PyTorch autograd produces esoteric errors if we save
-        # input tensor as context for backward pass.
-        if x_local is input:
-            x_local = x_local.detach()
-
-        # Configure input tensor for backward pass
-        if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
-            if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
-                # FP8 does not support all-gather of transpose data
-                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        # Prepare weight tensor for backward pass
+        if input_requires_grad:
+            if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensorBase):
+                w.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            w = None
+
+        # Prepare input tensor for backward pass
+        if weight_requires_grad:
+            if x_local is input:
+                # PyTorch autograd produces esoteric errors if we
+                # cache input tensor directly.
+                x_local = x_local.detach()
+            if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
+                if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
+                    # FP8 does not support all-gather of transpose data
+                    x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
+        else:
+            x_local = None

         # Return cast tensors
         extra_outputs = {"input": x_local, "weight": w}
...
@@ -298,6 +317,10 @@ class UserbuffersForwardLinear(FusedOperation):
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")

+        # Check which grads are required
+        input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
+
         # Quantization metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
...
@@ -306,8 +329,10 @@ class UserbuffersForwardLinear(FusedOperation):
         grad_input_quantizer = None
         if with_quantized_compute:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
-            if not recipe.delayed() and not recipe.mxfp8():
-                raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
+            if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
+                raise RuntimeError(
+                    f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
+                )
             input_quantizer = linear_op.get_quantizer("forward", 0)
             weight_quantizer = linear_op.get_quantizer("forward", 1)
             grad_output_quantizer = linear_op.get_quantizer("backward", 0)
...
@@ -338,12 +363,15 @@ class UserbuffersForwardLinear(FusedOperation):
             input_quantizer=input_quantizer,
             weight_quantizer=weight_quantizer,
             output_quantizer=None,  # Not supported
+            input_requires_grad=input_requires_grad,
+            weight_requires_grad=weight_requires_grad,
             ub_comm_name=linear_op._userbuffers_options["comm_name"],
         )
         x_local = extra_outputs["input"]
+        w = extra_outputs["weight"]

         # Save state for backward pass
-        linear_op_ctx.save_for_backward(x_local)
+        linear_op_ctx.save_for_backward(x_local, w)
         linear_op_ctx.with_quantized_compute = with_quantized_compute
         linear_op_ctx.input_quantizer = input_quantizer
         linear_op_ctx.weight_quantizer = weight_quantizer
...
@@ -351,8 +379,8 @@ class UserbuffersForwardLinear(FusedOperation):
         linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_dims = input_.size()
-        linear_op_ctx.input_requires_grad = input_.requires_grad
-        linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
+        linear_op_ctx.input_requires_grad = input_requires_grad
+        linear_op_ctx.weight_requires_grad = weight_requires_grad
         linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None

         return output, [() for _ in range(len(self.basic_ops))]
...
...
transformer_engine/pytorch/ops/fuser.py
View file @ 2b05e121
...
@@ -61,13 +61,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
     def forward(
         func_ctx: Optional[torch.autograd.function.FunctionCtx],
         input_: torch.Tensor,
-        forward_ops: list[tuple[FusibleOperation, list[int]]],
-        backward_ops: list[tuple[FusibleOperation, list[int]]],
-        basic_ops: list[BasicOperation],
+        fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
         is_grad_enabled: bool,
-        num_params: int,
-        num_extra_inputs: int,
         *params_and_extra_inputs: torch.nn.Parameter,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass
...
@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
...
@@ -78,20 +74,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
Context for PyTorch autograd function
Context for PyTorch autograd function
input_: torch.Tensor
input_: torch.Tensor
Input to first operation in pipeline
Input to first operation in pipeline
forward_ops: list of tuple
fuser: OperationFuser
Forward pass operations and the indices of the
Container for the pipeline of operations to run
corresponding basic operations. The order should match
basic_ops.
backward_ops: list of tuple
Backward pass operations and the indices of the
corresponding basic operations. The order should be the
reverse of basic_ops.
basic_ops: list of BasicOperation
Basic operations
basic_op_kwargs: list of dict
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
Keyword arguments to BasicOperation
num_params: int
is_grad_enabled: bool
Number of parameter tensors to include in autograd graph.
Should context be saved for backward
*params_and_extra_inputs: torch.Tensor
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
of parameter tensors, followed by extra operation inputs.
...
@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
...
@@ -106,26 +94,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
"""
"""
# Operation autograd contexts
# Operation autograd contexts
basic_op_ctxs
=
[
OperationContext
()
for
_
in
range
(
len
(
basic_ops
)
)
]
basic_op_ctxs
=
[
OperationContext
()
for
_
in
range
(
fuser
.
_num_
basic_ops
)]
# Unflatten list of parameters and extra tensor inputs
# Unflatten list of parameters and extra tensor inputs
if
len
(
params_and_extra_inputs
)
!=
num_params
+
num_extra_inputs
:
extra_inputs
=
params_and_extra_inputs
[
-
fuser
.
_num_extra_inputs
:]
raise
ValueError
(
f
"Expected
{
num_params
+
num_extra_inputs
}
extra tensor arguments "
f
"(
{
num_params
}
parameters,
{
num_extra_inputs
}
extra inputs), "
f
"but got
{
len
(
params_and_extra_inputs
)
}
"
)
_
,
extra_inputs
=
_split_tuple
(
params_and_extra_inputs
,
num_params
)
basic_op_extra_inputs
=
[]
basic_op_extra_inputs
=
[]
for
op
in
basic_ops
:
for
op
in
fuser
.
_
basic_ops
:
xs
,
extra_inputs
=
_split_tuple
(
extra_inputs
,
op
.
num_extra_inputs
)
xs
,
extra_inputs
=
_split_tuple
(
extra_inputs
,
op
.
num_extra_inputs
)
basic_op_extra_inputs
.
append
(
xs
)
basic_op_extra_inputs
.
append
(
xs
)
# Apply forward ops
# Apply forward ops
x
=
input_
x
=
input_
requires_grad
=
is_grad_enabled
and
x
.
requires_grad
requires_grad
=
is_grad_enabled
and
x
.
requires_grad
extra_outputs
=
[
None
for
_
in
range
(
len
(
basic_ops
))]
extra_outputs
=
[
None
]
*
fuser
.
_num_
basic_ops
for
op
,
basic_op_idxs
in
forward_ops
:
for
op
,
basic_op_idxs
in
fuser
.
_
forward_ops
:
# Check if backward op is required
# Check if backward op is required
if
is_grad_enabled
:
if
is_grad_enabled
:
...
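With the fuser object passed in directly, the forward pass slices the flattened per-op extra inputs off the end of *params_and_extra_inputs and splits them per basic op. A self-contained sketch of that unflattening step follows; _split_tuple is reproduced here for illustration and the op objects are assumed to expose num_extra_inputs:

# Sketch of the per-op unflattening pattern used above.
def _split_tuple(t, idx):
    # Split a tuple into its first idx elements and the rest.
    return t[:idx], t[idx:]

def unflatten_extra_inputs(flat_inputs, ops):
    per_op = []
    remaining = flat_inputs
    for op in ops:
        xs, remaining = _split_tuple(remaining, op.num_extra_inputs)
        per_op.append(xs)
    return per_op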
@@ -143,9 +125,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
-            prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
+            prev_ops = [fuser._basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
             next_ops = [
-                basic_ops[idx + 1] if (idx < len(basic_ops) - 1) else None for idx in basic_op_idxs
+                fuser._basic_ops[idx + 1] if (idx < fuser._num_basic_ops - 1) else None
+                for idx in basic_op_idxs
             ]
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
...
@@ -165,7 +148,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         extra_outputs_flat = []
         for idx, ys in enumerate(extra_outputs):
             ys = list(ys)
-            num_extra_outputs = basic_ops[idx].num_extra_outputs
+            num_extra_outputs = fuser._basic_ops[idx].num_extra_outputs
             if len(ys) != num_extra_outputs:
                 raise RuntimeError(
                     f"Expected op {idx} to generate "
...
@@ -189,11 +172,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             func_ctx.save_for_backward(*to_save)

             # Other context
-            func_ctx.backward_ops = backward_ops
-            func_ctx.basic_ops = basic_ops
+            func_ctx.backward_ops = fuser._backward_ops
+            func_ctx.basic_ops = fuser._basic_ops
             func_ctx.basic_op_ctxs = basic_op_ctxs
-            func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
-            func_ctx.num_extra_inputs = num_extra_inputs
+            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
+            func_ctx.num_extra_inputs = fuser._num_extra_inputs
             func_ctx.num_extra_outputs = len(extra_outputs_flat)
             func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
...
@@ -216,8 +199,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         basic_op_ctxs = func_ctx.basic_op_ctxs

         # Unflatten list of saved tensors
+        saved_tensors = func_ctx.saved_tensors
         for ctx in basic_op_ctxs:
-            ctx.saved_tensors = func_ctx.saved_tensors[slice(*ctx._saved_tensors_range)]
+            ctx.saved_tensors = saved_tensors[slice(*ctx._saved_tensors_range)]
             ctx._saved_tensors_range = None

         # Unflatten list of extra tensor output grads
...
@@ -292,13 +276,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         return (
             dx,  # input_
-            None,  # forward_ops
-            None,  # backward_ops
-            None,  # basic_ops
+            None,  # fuser
             None,  # basic_op_kwargs
             None,  # is_grad_enabled
-            None,  # num_params
-            None,  # num_extra_inputs
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )
...
@@ -345,6 +325,10 @@ class OperationFuser:
         if fuse_ops:
             self.fuse_ops()

+        # Flatten list of parameters
+        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
+        self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops]
+
     @classmethod
     def _fuse_forward_ops(
         cls,
...
@@ -377,6 +361,11 @@ class OperationFuser:
         *extra_inputs: torch.Tensor,
         basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:

+        # Verify extra input count
+        if len(extra_inputs) != self._num_extra_inputs:
+            raise ValueError(
+                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
+            )
+
         # Initialization before forward pass
         for op in self._basic_ops:
...
@@ -384,10 +373,7 @@ class OperationFuser:
         # Canonicalize op kwargs
         if basic_op_kwargs is None:
-            basic_op_kwargs = [{} for _ in range(len(self._basic_ops))]
-
-        # Flatten list of parameters
-        params = [param for op in self._basic_ops for param in op.parameters()]
+            basic_op_kwargs = [{}] * self._num_basic_ops

         # Fuser forward pass
         is_grad_enabled = torch.is_grad_enabled()
...
@@ -399,14 +385,10 @@ class OperationFuser:
             args = [None]
         args += (
             input,
-            self._forward_ops,
-            self._backward_ops,
-            self._basic_ops,
+            self,
             basic_op_kwargs,
             is_grad_enabled,
-            len(params),
-            self._num_extra_inputs,
-            *params,
+            *self._basic_op_params,
             *extra_inputs,
         )
         return forward_func(*args)
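Taken together, these OperationFuser changes move the flattening of parameters out of the per-call path: the flattened parameter list and per-op parameter counts are computed once when the fuser is constructed and reused when packing arguments for the autograd function. A condensed sketch of that caching pattern (class and method names simplified, not the full OperationFuser):

# Condensed sketch; assumes each op exposes parameters() like an nn.Module.
class _FuserSketch:
    def __init__(self, basic_ops):
        self._basic_ops = basic_ops
        self._num_basic_ops = len(basic_ops)
        # Computed once here rather than rebuilt on every forward call.
        self._basic_op_params = [p for op in basic_ops for p in op.parameters()]
        self._num_list_basic_op_params = [
            sum(1 for _ in op.parameters()) for op in basic_ops
        ]

    def pack_forward_args(self, input_, basic_op_kwargs, is_grad_enabled, extra_inputs):
        # Mirrors the argument packing for the autograd function above.
        return (input_, self, basic_op_kwargs, is_grad_enabled,
                *self._basic_op_params, *extra_inputs)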
transformer_engine/pytorch/pyproject.toml
0 → 100755
View file @
2b05e121
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
transformer_engine/pytorch/setup.py
View file @
2b05e121
...
@@ -31,7 +31,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
 from build_tools.build_ext import get_build_ext
 from build_tools.utils import copy_common_headers
 from build_tools.te_version import te_version
-from build_tools.pytorch import setup_pytorch_extension
+from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements

 os.environ["NVTE_PROJECT_BUILDING"] = "1"
...
@@ -55,18 +55,8 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        setup_requires=[
-            "torch>=2.1",
-            "nvidia-cuda-runtime-cu12",
-            "nvidia-cublas-cu12",
-            "nvidia-cudnn-cu12",
-            "nvidia-cuda-cccl-cu12",
-            "nvidia-cuda-nvcc-cu12",
-            "nvidia-nvtx-cu12",
-            "nvidia-cuda-nvrtc-cu12",
-        ],
-        install_requires=["torch>=2.1"],
-        tests_require=["numpy", "torchvision"],
+        install_requires=install_requirements(),
+        tests_require=test_requirements(),
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
         shutil.rmtree(common_headers_dir)
...
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
View file @
2b05e121
...
@@ -11,6 +11,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine_torch import Float8BlockScaleTensorFormat

 from ..quantized_tensor import QuantizedTensorBase
...
@@ -37,6 +38,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     _rowwise_scale_inv: Optional[torch.Tensor]
     _columnwise_scale_inv: Optional[torch.Tensor]
     _is_2D_scaled: bool
+    _data_format: Float8BlockScaleTensorFormat

     def __new__(
         cls,
...
@@ -48,6 +50,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         fp8_dtype: TE_DType,
         quantizer: Quantizer,
         is_2D_scaled: bool,
+        data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
         **kwargs,
     ):
         instance = super().__new__(cls, *args, **kwargs)
...
@@ -58,6 +61,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv
         instance._is_2D_scaled = is_2D_scaled
+        instance._data_format = data_format
         return instance
...
@@ -82,8 +86,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
             "fp8_dtype": self._fp8_dtype,
             "quantizer": self._quantizer,
             "is_2D_scaled": self._is_2D_scaled,
+            "data_format": self._data_format,
         }

+    def _is_gemm_ready_format(self) -> bool:
+        """Whether data is in GEMM_READY format"""
+        return self._data_format == Float8BlockScaleTensorFormat.GEMM_READY
+
     def prepare_for_saving(
         self,
     ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
...
@@ -136,34 +145,69 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
             q_K = q.shape[-1]
             for i in range(len(q.shape) - 1):
                 q_M *= q.shape[i]
+            inner_q_dimension_tiled = True
+            if self._is_gemm_ready_format():
+                scales_tiled_dim, scales_untiled_dim = scale_inv.shape
+                inner_scale_dimension_tiled = False
+                scales_are_compact = False
+            else:
+                scales_untiled_dim, scales_tiled_dim = scale_inv.shape
+                inner_scale_dimension_tiled = True
+                scales_are_compact = True
         else:
             assert self._columnwise_data is not None, "No data to dequantize"
             q = self._columnwise_data
             scale_inv = self._columnwise_scale_inv
-            transpose_output = True
-            if len(q.shape) >= 1:
-                q_M = q.shape[0]
-                for i in range(1, len(q.shape)):
-                    q_K *= q.shape[i]
+            scales_tiled_dim, scales_untiled_dim = scale_inv.shape
+            inner_scale_dimension_tiled = False
+            if self._is_gemm_ready_format():
+                inner_q_dimension_tiled = True
+                transpose_output = True
+                if len(q.shape) >= 1:
+                    q_M = q.shape[0]
+                    for i in range(1, len(q.shape)):
+                        q_K *= q.shape[i]
+                scales_are_compact = False
+            else:
+                inner_q_dimension_tiled = False
+                transpose_output = False
+                if len(q.shape) >= 1:
+                    q_K = q.shape[-1]
+                    for i in range(len(q.shape) - 1):
+                        q_M *= q.shape[i]
+                scales_are_compact = True

         orig_shape = q.shape
         q = q.reshape(q_M, q_K)
-        k_tiles, scale_m = scale_inv.shape
-        if q_K % block_len != 0:
-            k_pad_amount = (block_len - (q_K % block_len)) % block_len
-            q = torch.nn.functional.pad(
-                q, (0, k_pad_amount, 0, 0), mode="constant", value=0
-            ).contiguous()
-        _, padded_K = q.shape
-        q_tiled = q.reshape(q_M, k_tiles, block_len)
-        if scale_m > q_M:
-            # scale_m is 4 element aligned.
+        if inner_q_dimension_tiled:
+            if q_K % block_len != 0:
+                k_pad_amount = (block_len - (q_K % block_len)) % block_len
+                q = torch.nn.functional.pad(
+                    q, (0, k_pad_amount, 0, 0), mode="constant", value=0
+                ).contiguous()
+            padded_M, padded_K = q.shape
+            q_tiled = q.reshape(q_M, scales_tiled_dim, block_len)
+        else:
+            if q_M % block_len != 0:
+                m_pad_amount = (block_len - (q_M % block_len)) % block_len
+                q = torch.nn.functional.pad(
+                    q, (0, 0, 0, m_pad_amount), mode="constant", value=0
+                ).contiguous()
+            padded_M, padded_K = q.shape
+            q_tiled = q.reshape(scales_tiled_dim, block_len, q_K)
+        if not scales_are_compact and scales_untiled_dim > q_M:
+            # untiled scale dimension is 4 element aligned.
             scale_inv = scale_inv[:, :q_M].contiguous()
-        dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
+        if scales_are_compact and inner_scale_dimension_tiled:
+            dq_scale = scale_inv.contiguous().reshape(q_M, scales_tiled_dim, 1)
+        elif scales_are_compact and not inner_scale_dimension_tiled:
+            dq_scale = scale_inv.contiguous().reshape(scales_tiled_dim, 1, q_K)
+        else:
+            dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, scales_tiled_dim, 1)
         torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
         result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
-        if padded_K != q_K:
-            result = result.reshape(q_M, padded_K)[:, :q_K]
+        if padded_M != q_M or padded_K != q_K:
+            result = result.reshape(padded_M, padded_K)[:q_M, :q_K]
         result = result.to(dtype)
         if len(orig_shape) == 0:
             result = result.reshape([])
...
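The row-wise branch above amounts to 1x128 block dequantization: pad the inner dimension up to a multiple of the block length, view the data as (rows, tiles, block_len), multiply each tile by its scale, and slice the padding back off. A standalone numeric sketch of that arithmetic with plain float tensors (shapes are illustrative; the real code operates on FP8 data and cuBLAS-padded scales):

import torch

block_len = 128
q = torch.randn(4, 200)      # "quantized" data; inner dim is not a multiple of 128
scale = torch.rand(4, 2)     # one scale per (row, tile); ceil(200 / 128) = 2 tiles

q_M, q_K = q.shape
pad = (block_len - q_K % block_len) % block_len
q_padded = torch.nn.functional.pad(q, (0, pad))     # (4, 256)
q_tiled = q_padded.reshape(q_M, -1, block_len)      # (4, 2, 128)
dequant = q_tiled * scale.reshape(q_M, -1, 1)       # broadcast one scale over each tile
dequant = dequant.reshape(q_M, q_K + pad)[:, :q_K]  # drop the padding again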
@@ -182,6 +226,12 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         if not self._is_2D_scaled:
             return self._dequantize_vectorwise(dtype=dtype)
+        if not self._is_gemm_ready_format():
+            raise NotImplementedError(
+                "Dequantize is only supported with GEMM_READY data format, "
+                f"but found _data_format={self._data_format}"
+            )

         def format_scale_as_logical_shape(q_K, scales, block_len):
             # The GEMM for 2D blocks required padding in the scales.
             derived_scale_k_shape = math.ceil(q_K / block_len)
...
@@ -247,6 +297,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         if self._rowwise_data is not None:
             return self._rowwise_data.size(*args, **kwargs)
         dims = list(self._columnwise_data.size(*args, **kwargs))
+        if not self._is_gemm_ready_format():
+            # compact format
+            return torch.Size(dims)
         reordered = []
         for i in range(1, len(dims)):
             reordered.append(dims[i])
...
@@ -285,6 +337,13 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
         self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])

+    def _transpose_columnwise_data(self):
+        """Plainly transpose the columnwise data and scale inv."""
+        if self._columnwise_data is not None:
+            self._columnwise_data = tex.fp8_transpose(
+                self._columnwise_data, self._fp8_dtype, out=None
+            )
+
     def __repr__(self):
         if self._rowwise_data is not None:
             data = self.dequantize()
...
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
View file @
2b05e121
...
@@ -4,13 +4,15 @@
 """Tensor class with FP8 data quantized with NxN tiles"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable
+from typing import Optional, Tuple, Iterable, Union
 import math

 import torch
 import transformer_engine_torch as tex
 import os

 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import Float8BlockScaling, Recipe
 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..utils import devices_match, round_up_to_nearest_multiple
...
@@ -33,6 +35,8 @@ class Float8BlockQuantizer(Quantizer):
     amax_epsilon: float
     force_pow_2_scales: bool
     block_scaling_dim: int
+    # Whether to produce tensors that will be used in all-gather
+    all_gather_usage: bool

     def __init__(
         self,
...
@@ -43,6 +47,7 @@ class Float8BlockQuantizer(Quantizer):
         amax_epsilon: float = 0.0,
         force_pow_2_scales: bool = True,
         block_scaling_dim: int = 2,
+        all_gather_usage: bool = False,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
...
@@ -50,6 +55,7 @@ class Float8BlockQuantizer(Quantizer):
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
         self.block_scaling_dim = block_scaling_dim
+        self.all_gather_usage = all_gather_usage

     def update_quantized(
         self,
...
@@ -126,22 +132,36 @@ class Float8BlockQuantizer(Quantizer):
             M *= shape[i]
         if len(shape) > 0:
             K = shape[-1]
+        # 2D 128x128 quantization block scaling
+        # CuBLAS requries 128x128 scaling factor to be padded
+        # currently rowwise and columnwise format option doesn't apply to 2D scaling
         if self.block_scaling_dim == 2:
             if columnwise:
                 outer = math.ceil(K / self.block_len)
                 inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
                 return (outer, inner)
+            # rowwise
             outer = math.ceil(M / self.block_len)
             inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
             return (outer, inner)
+        # 1D 1x128 quantization block scaling
+        # CuBLAS requries 1x128 scaling factor to be padded and transposed
         assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
         if columnwise:
+            columnwise_compact = self.all_gather_usage
             outer = math.ceil(M / self.block_len)
-            inner = round_up_to_nearest_multiple(K, 4)
+            inner = round_up_to_nearest_multiple(K, 4) if not columnwise_compact else K
+            # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS
+            # for COMPACT case, since we apply 1x128 scaling here without transposing columnwise data, scaling factor is also [outer, inner]
+            # so no need to swap inner outer here
             return (outer, inner)
+        # rowwise
+        rowwise_compact = self.all_gather_usage
         outer = math.ceil(K / self.block_len)
-        inner = round_up_to_nearest_multiple(M, 4)
-        return (outer, inner)
+        inner = round_up_to_nearest_multiple(M, 4) if not rowwise_compact else M
+        # GEMM READY case: scaling factor is [outer, inner], already transposed here for CuBLAS need
+        # for COMPACT case, since we apply 128x1 scaling, scaling block applies to inner dim, so we need to swap outer and inner here
+        return (outer, inner) if not rowwise_compact else (inner, outer)

     def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
         """Calculate the shape of a tensor after columnwise permutation.
...
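For 1D (1x128) scaling, the scale-inverse shape now depends on the target layout: GEMM_READY keeps the tile count as the outer dimension and pads the other dimension up to a multiple of 4 (already transposed for cuBLAS), while the COMPACT all-gather layout skips the padding and, in the row-wise case, swaps the two dimensions to follow the data layout. A small worked example of that arithmetic (the rounding helper is reproduced here for illustration):

import math

def round_up_to_nearest_multiple(x, m):
    # Reproduced for illustration.
    return ((x + m - 1) // m) * m

# Row-wise 1x128 scaling of an (M, K) = (1024, 4096) tensor with block_len = 128.
M, K, block_len = 1024, 4096, 128
tiles = math.ceil(K / block_len)                           # 32 scale tiles per row

gemm_ready = (tiles, round_up_to_nearest_multiple(M, 4))   # (32, 1024): padded and transposed for cuBLAS
compact = (M, tiles)                                       # (1024, 32): follows the data layout for all-gather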
@@ -163,15 +183,25 @@ class Float8BlockQuantizer(Quantizer):
         """
         if len(shape) == 0:
             return tuple()
+        # currently columnwise format option only applies to 1D quantizer
+        # for 2D scaling, columnwise format should always be GEMM_READY_DATA_AND_SCALES
+        # since currently 2D scaling only applies to module weights
+        if self.block_scaling_dim == 1 and self.all_gather_usage:
+            return shape
         colwise_shape = [shape[-1]]
         for i in range(len(shape) - 1):
             colwise_shape.append(shape[i])
         return tuple(colwise_shape)

-    # TODO(kwyss): With FP8 gather support, we need to implement a
-    # shape/layout/swizzle check to know whether FP8 gather works
-    # cleanly by stacking data without aliasing tiles and whether
-    # the scales also stack on the proper dimensions.
+    def is_quantizable(self, inp: torch.Tensor) -> bool:
+        """Returns whether or not given inp can be quantized"""
+        if inp.ndim < 2:
+            return False
+        if inp.shape[-1] % self.block_len != 0:
+            return False
+        if math.prod(inp.shape[:-1]) % self.block_len != 0:
+            return False
+        return True

     def make_empty(
         self,
...
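The new is_quantizable method gates blockwise quantization on shape divisibility: inputs with fewer than two dimensions are rejected, and both the last dimension and the product of the leading dimensions must be multiples of the block length. A quick sketch of that rule in isolation (block_len = 128 assumed; the helper name is illustrative):

import math
import torch

block_len = 128

def can_blockwise_quantize(inp: torch.Tensor) -> bool:
    # Same divisibility rule as is_quantizable above.
    if inp.ndim < 2:
        return False
    if inp.shape[-1] % block_len != 0:
        return False
    if math.prod(inp.shape[:-1]) % block_len != 0:
        return False
    return True

assert can_blockwise_quantize(torch.empty(256, 1024))      # 256 and 1024 are multiples of 128
assert not can_blockwise_quantize(torch.empty(200, 1000))  # neither dimension is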
@@ -185,6 +215,12 @@ class Float8BlockQuantizer(Quantizer):
         if device is None:
             device = torch.device("cuda")
+        data_format = (
+            tex.Float8BlockScaleTensorFormat.COMPACT
+            if self.all_gather_usage
+            else tex.Float8BlockScaleTensorFormat.GEMM_READY
+        )

         # Allocate FP8 data
         data = None
         scale_inv = None
...
@@ -222,6 +258,7 @@ class Float8BlockQuantizer(Quantizer):
             columnwise_scale_inv=columnwise_scale_inv,
             quantizer=self,
             is_2D_scaled=self.block_scaling_dim == 2,
+            data_format=data_format,
             requires_grad=requires_grad,
         )
...
@@ -230,6 +267,9 @@ class Float8BlockQuantizer(Quantizer):
         # where state from an estimator influences distribution parameters.
         pass

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return Float8BlockScaling
+

 class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
     """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
...
@@ -260,7 +300,8 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         return (
             f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
             f" is_2D_scaled={self._is_2D_scaled},"
-            f" data={self.dequantize(dtype=self.dtype)})"
+            f" data={self.dequantize(dtype=self.dtype)}),"
+            f" data_format={self._data_format}"
         )

     def _get_quantizer(self) -> Quantizer:
...
@@ -393,6 +434,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         dtype: torch.dtype,
         quantizer: Quantizer,
         is_2D_scaled: bool,
+        data_format: tex.Float8BlockScaleTensorFormat,
     ) -> Float8BlockwiseQTensor:
         """Build Float8BlockwiseQTensor, for use in __reduce__
...
@@ -410,6 +452,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
             dtype=dtype,
             quantizer=quantizer,
             is_2D_scaled=is_2D_scaled,
+            data_format=data_format,
         )

     def __reduce_ex__(self, protocol: int) -> tuple:
...
@@ -426,6 +469,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
                 self.dtype,
                 self._quantizer,
                 self._is_2D_scaled,
+                self._data_format,
             ),
         )
...
@@ -451,6 +495,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         dst._fp8_dtype = src._fp8_dtype
         dst._rowwise_scale_inv = src._rowwise_scale_inv
         dst._columnwise_scale_inv = src._columnwise_scale_inv
+        dst._data_format = src._data_format

         # Check that tensor dimensions match
         if (
...
@@ -498,6 +543,13 @@ class _ViewFunc(torch.autograd.Function):
     ) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring

+        # Check for invalid configurations
+        if not tensor._is_gemm_ready_format():
+            raise NotImplementedError(
+                "View is only supported with GEMM_READY data format, "
+                f"but found data_format={tensor._data_format}"
+            )
+
         # Return input tensor if shape is not provided
         ctx.shape = tensor.shape
         if shape is None:
...
@@ -566,6 +618,14 @@ class _ViewFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         if isinstance(grad, Float8BlockwiseQTensor):

+            # Check for invalid configurations
+            if not grad._is_gemm_ready_format():
+                raise NotImplementedError(
+                    "View is only supported with GEMM_READY data format, "
+                    f"but found data_format={grad._data_format}"
+                )
+
             new_data = (
                 grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
             )
...
@@ -605,6 +665,13 @@ class _ReshapeFunc(torch.autograd.Function):
     ) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring

+        # Check for invalid configurations
+        if not tensor._is_gemm_ready_format():
+            raise NotImplementedError(
+                "Reshape is only supported with GEMM_READY data format, "
+                f"but found data_format={tensor._data_format}"
+            )
+
         # Return input tensor if shape is not provided
         ctx.shape = tensor.shape
         if shape is None:
...
@@ -672,6 +739,14 @@ class _ReshapeFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         if isinstance(grad, Float8BlockwiseQTensor):

+            # Check for invalid configurations
+            if not grad._is_gemm_ready_format():
+                raise NotImplementedError(
+                    "Reshape is only supported with GEMM_READY data format, "
+                    f"but found data_format={grad._data_format}"
+                )
+
             new_rowwise_data = None
             new_columnwise_data = None
             if grad._rowwise_data is not None:
...