OpenDAS / TransformerEngine · Commit 0d874a4e
Authored Mar 03, 2026 by wenjh

    Merge branch 'nv_main' of v2.12

Parents: a68e5f87, dfdd3820
Changes: 646 files in total; showing 20 changed files with 1806 additions and 991 deletions (+1806 −991)
transformer_engine/pytorch/custom_recipes/gemm.py                           +1    -1
transformer_engine/pytorch/custom_recipes/quantization.py                   +1    -1
transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py   +525  -0
transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py             +4    -4
transformer_engine/pytorch/custom_recipes/utils.py                          +1    -1
transformer_engine/pytorch/distributed.py                                   +170  -163
transformer_engine/pytorch/export.py                                        +2    -2
transformer_engine/pytorch/float8_tensor.py                                 +1    -1
transformer_engine/pytorch/fp8.py                                           +1    -1
transformer_engine/pytorch/graph.py                                         +156  -60
transformer_engine/pytorch/jit.py                                           +2    -2
transformer_engine/pytorch/module/__init__.py                               +1    -1
transformer_engine/pytorch/module/_common.py                                +1    -1
transformer_engine/pytorch/module/base.py                                   +74   -108
transformer_engine/pytorch/module/fp8_padding.py                            +19   -21
transformer_engine/pytorch/module/fp8_unpadding.py                          +20   -22
transformer_engine/pytorch/module/grouped_linear.py                         +194  -130
transformer_engine/pytorch/module/layernorm.py                              +10   -13
transformer_engine/pytorch/module/layernorm_linear.py                       +111  -154
transformer_engine/pytorch/module/layernorm_mlp.py                          +512  -305
transformer_engine/pytorch/custom_recipes/gemm.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...

transformer_engine/pytorch/custom_recipes/quantization.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py (new file, mode 100644)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Current scaling recipe reference implementation."""
import dataclasses
import math
from typing import Optional, Tuple, Iterable

import torch

from transformer_engine.pytorch.custom_recipes import quantization
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer


def current_scaling_ref_quantizer_factory(role):
    """Factory function for current scaling reference quantizer.

    Usage with CustomRecipe and autocast:
        custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory)
        with autocast(recipe=custom_recipe):
            output = model(input)
    """
    if role in ("linear_input", "linear_weight"):
        dtype = torch.float8_e4m3fn
    elif role in ("linear_output", "linear_grad_output"):
        dtype = torch.float8_e5m2
    else:
        return None
    return CurrentScalingQuantizerRef(
        dtype=dtype,
        rowwise=True,
        columnwise=True,
        pow_2_scales=False,
        eps=0.0,
    )
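For orientation, a minimal usage sketch of the factory above, following the docstring; the exact import locations of `recipe` and `autocast` are assumptions and are not confirmed by this diff.

    # Hedged sketch: module paths for `recipe` and `autocast` are assumed, not shown here.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    model = te.Linear(1024, 1024).cuda()
    x = torch.randn(16, 1024, device="cuda")

    # The factory maps roles to FP8 dtypes: e4m3 for inputs/weights, e5m2 for outputs/grad-outputs.
    custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory)
    with te.autocast(recipe=custom_recipe):
        y = model(x)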
@dataclasses.dataclass
class CurrentScalingTensorRef(QuantizedTensorStorage):
    """Reference implementation of current scaling quantized tensor"""

    data: Optional[torch.Tensor] = None
    scale: Optional[torch.Tensor] = None
    data_t: Optional[torch.Tensor] = None
    scale_t: Optional[torch.Tensor] = None
    dtype: Optional[torch.dtype] = None
    device: Optional[torch.device] = None
    quant_dtype: Optional[torch.dtype] = None
    original_shape: Optional[Tuple[int, ...]] = None
    _quantizer: Optional[Quantizer] = None

    @property
    def custom(self) -> bool:
        """Flag to indicate this quantized tensor is custom."""
        return True

    def prepare_for_saving(
        self,
    ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
        """Prepare the quantization result for saving for backward"""
        tensors = [self.data, self.data_t, self.scale, self.scale_t]
        self.data = None
        self.data_t = None
        self.scale = None
        self.scale_t = None
        return tensors, self

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the quantization result from the saved tensors"""
        self.data = tensors[0]
        self.data_t = tensors[1]
        self.scale = tensors[2]
        self.scale_t = tensors[3]
        return tensors[4:]

    # Compatibility
    @property
    def _data(self):
        return self.data

    @_data.setter
    def _data(self, value):
        self.data = value

    @property
    def _scale_inv(self):
        return self.scale

    @_scale_inv.setter
    def _scale_inv(self, value):
        self.scale = value

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"dtype={self.dtype}, "
            f"device={self.device}, "
            f"quant_dtype={self.quant_dtype}, "
            f"original_shape={self.original_shape}"
            ")"
        )

    def update_usage(
        self,
        rowwise_usage: Optional[bool] = None,
        columnwise_usage: Optional[bool] = None,
    ):
        """Generate or remove quantized data based on provided usage."""
        has_data = self.data is not None
        has_data_transpose = self.data_t is not None
        needs_data = has_data
        needs_data_transpose = has_data_transpose
        if rowwise_usage is not None:
            needs_data = rowwise_usage
        if columnwise_usage is not None:
            needs_data_transpose = columnwise_usage

        # Generate data that is required
        if needs_data and not has_data:
            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
        if needs_data_transpose and not has_data_transpose:
            if not has_data:
                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
            self._create_transpose()

        # Delete data that is not required
        if not needs_data:
            self.data = None
        if not needs_data_transpose:
            self.data_t = None

    def _create_transpose(self):
        """Create transposed quantized tensor"""
        if not self.data.is_contiguous():
            self.data = self.data.contiguous()
        self.data_t = self.data.t().contiguous()
        self.scale_t = self.scale

    def size(self, *args, **kwargs):
        """Get the size of the quantized tensor"""
        if self.data is not None:
            return self.data.size(*args, **kwargs)
        size = self.data_t.size(*args, **kwargs)
        return torch.Size([size[-1], math.prod(size[:-1])])
def _scale_from_amax_tensor(
    x_dtype: torch.dtype,
    amax: torch.Tensor,
    quant_dtype: torch.dtype,
    *,
    eps: float,
    pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Derives quantization and dequantization from amax and options.

    Reference implementation for scale calculation.

    Returns:
        - scale: quantization scales
        - scale_inv: dequantization scales
        - amax: Amax tensor with updates made for extrema values.
    """
    assert amax.dtype == torch.float, "amax must be a float tensor."
    fp8_max = torch.finfo(quant_dtype).max
    # Clamping amax to avoid division by small numbers
    amax = torch.max(amax, torch.tensor(eps))
    # Compute scale factor
    scale = torch.div(fp8_max, amax)
    # Take care of inf before pow_2_scales
    scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
    if pow_2_scales:
        _, exp = torch.frexp(scale)
        exp = exp - 1
        assert (exp > -127).all()
        unity = torch.tensor([1.0], device=exp.device)
        torch.ldexp(unity, exp, out=scale)
        scale = torch.where(amax == float("inf"), 0.0, scale)
    # Handle overflow cases for amax zero causing NaN
    scale = torch.where(amax == 0, 1.0, scale)
    # Compute scale_inv
    scale_inv = torch.reciprocal(scale)
    return scale, scale_inv, amax
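A small worked check of the scale arithmetic above (illustrative values; this assumes a PyTorch build with FP8 dtypes, where torch.finfo(torch.float8_e4m3fn).max is 448):

    # Illustrative check of _scale_from_amax_tensor: scale = fp8_max / amax, scale_inv = 1 / scale.
    amax = torch.tensor([2.0], dtype=torch.float)
    scale, scale_inv, _ = _scale_from_amax_tensor(
        torch.float32, amax, torch.float8_e4m3fn, eps=0.0, pow_2_scales=False
    )
    # scale == 448 / 2 == 224.0, scale_inv == 1 / 224 ≈ 0.00446

    # With pow_2_scales=True the scale is rounded down to a power of two:
    scale_p2, _, _ = _scale_from_amax_tensor(
        torch.float32, torch.tensor([2.0]), torch.float8_e4m3fn, eps=0.0, pow_2_scales=True
    )
    # scale_p2 == 128.0, the largest power of two not exceeding 224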
class CurrentScalingQuantizerRef(Quantizer):
    """Reference implementation of current scaling quantizer"""

    def __init__(
        self,
        dtype: torch.dtype,
        rowwise: bool = True,
        columnwise: bool = True,
        pow_2_scales: bool = False,
        eps: float = 0.0,
    ):
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.internal = True
        self.dtype = dtype
        self.pow_2_scales = pow_2_scales
        self.eps = eps
        self.with_amax_reduction = False
        self.amax_reduction_group = None

    @property
    def custom(self) -> bool:
        """Flag to indicate this quantizer is custom."""
        return True

    @property
    def supports_allgather_fp8(self) -> bool:
        """Flag to indicate this quantizer supports allgather fp8"""
        return True

    @classmethod
    def compute_scale(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        eps=0.0,
        pow_2_scales: bool = False,
    ):
        """Compute the scale from the amax tensor"""
        # Use float32 for computation
        x_fp32 = x.to(torch.float32)
        if x_fp32.numel() == 0:
            amax = torch.empty(1, dtype=torch.float32, device=x.device)
        else:
            amax = torch.amax(torch.abs(x_fp32)).view(1)
        return _scale_from_amax_tensor(
            x.dtype,
            amax=amax,
            quant_dtype=quant_dtype,
            eps=eps,
            pow_2_scales=pow_2_scales,
        )
    def _quantize(
        self, tensor: torch.Tensor
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        """
        Python implementation of quantization (c++ kernel can be used as an option instead).

        Parameters
        ----------
        tensor : torch.Tensor
            Input tensor to quantize (should be 2D)

        Returns
        -------
        Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
            (qx, sx, qx_t, sx_t) where:
            - qx: quantized data in row-major order (if rowwise_usage), None otherwise
            - sx: empty scale tensor for qx (if rowwise_usage), None otherwise
            - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise
            - sx_t: empty scale tensor for qx_t (if columnwise_usage), None otherwise
        """
        # Handle amax reduction if enabled
        if self.with_amax_reduction:
            assert (
                self.amax_reduction_group is not None
            ), "amax_reduction_group must be set when with_amax_reduction is True"
            # Compute local amax
            if tensor.numel() == 0:
                amax = torch.empty(1, dtype=torch.float32, device=tensor.device)
            else:
                amax = torch.amax(torch.abs(tensor)).view(1).to(torch.float32)
            # Reduce amax across all ranks
            torch.distributed.all_reduce(
                amax, group=self.amax_reduction_group, op=torch.distributed.ReduceOp.MAX
            )
            # Compute scale using the global amax
            scale, scale_inv, _ = _scale_from_amax_tensor(
                tensor.dtype,
                amax=amax,
                quant_dtype=self.dtype,
                eps=self.eps,
                pow_2_scales=self.pow_2_scales,
            )
        else:
            # compute scale factor using local amax
            scale, scale_inv, _ = self.compute_scale(
                tensor,
                self.dtype,
                eps=self.eps,
                pow_2_scales=self.pow_2_scales,
            )
        qx: Optional[torch.Tensor] = (tensor.float() * scale).to(self.dtype)
        sx: Optional[torch.Tensor] = scale_inv
        # transpose if needed
        if self.columnwise_usage:
            assert qx is not None
            qx_t = qx.t().contiguous()
            sx_t = sx
        else:
            qx_t, sx_t = None, None
        if not self.rowwise_usage:
            qx = None
            sx = None
        return qx, sx, qx_t, sx_t
    def quantize(
        self,
        tensor: torch.Tensor,
        **kwargs,  # pylint: disable=unused-argument
    ) -> CurrentScalingTensorRef:
        # sanity checks
        assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype."
        # Make it work with 3D tensors
        original_shape = tensor.shape
        if tensor.ndim > 2:
            tensor = tensor.view(-1, tensor.shape[-1])
        qx, sx, qx_t, sx_t = self._quantize(tensor)
        return CurrentScalingTensorRef(
            data=qx,
            scale=sx,
            data_t=qx_t,
            scale_t=sx_t,
            dtype=tensor.dtype,
            device=tensor.device,
            quant_dtype=self.dtype,
            _quantizer=self,
            original_shape=original_shape,
        )
    def dequantize(
        self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        """Dequantize the quantized tensor"""
        tensor = tensor.to(torch.float32) * scale
        if dtype is None:
            return tensor
        return tensor.to(dtype)
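To make the quantize/dequantize contract concrete, here is a self-contained round-trip sketch that mirrors the reference math above without going through the TE classes (it assumes a PyTorch build with FP8 dtype support):

    x = torch.randn(64, 128, dtype=torch.bfloat16)

    # Per-tensor current scaling, as in CurrentScalingQuantizerRef._quantize.
    amax = torch.amax(torch.abs(x.float())).view(1)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax    # quantization scale
    scale_inv = 1.0 / scale                                # dequantization scale

    qx = (x.float() * scale).to(torch.float8_e4m3fn)          # quantize (row-wise data)
    x_hat = (qx.to(torch.float32) * scale_inv).to(x.dtype)    # dequantize, as in dequantize()

    # FP8 e4m3 keeps roughly 3 mantissa bits, so the round trip is close but lossy.
    max_abs_err = (x.float() - x_hat.float()).abs().max()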
    def qgemm(
        self,
        qx: torch.Tensor,
        qw: torch.Tensor,
        m_params: quantization.MMParams,
        out_dtype: torch.dtype,
        sx: torch.Tensor,
        sw: torch.Tensor,
        bias: torch.Tensor | None = None,
        out: torch.Tensor | None = None,
        accumulate: bool = False,
        gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP,  # pylint: disable=unused-argument
        qresult_x: QuantizedTensorStorage | None = None,  # pylint: disable=unused-argument
        qresult_w: QuantizedTensorStorage | None = None,  # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """Python implementation of quantized gemm."""
        M, K = qx.shape
        N, _ = qw.shape
        if M == 0 or K == 0 or N == 0:
            if accumulate:
                assert out is not None
                y = out
            else:
                y = torch.zeros((M, N), dtype=out_dtype, device=qx.device)
            if bias is not None:
                y += bias
            return y
        # cublas fp8 gemm does not support fp32 bias
        use_bias_in_gemm = (
            bias is not None and out_dtype != torch.float32 and bias.dtype != torch.float32
        )
        # Run quantized gemm: y = qw * qx
        scaled_mm_res = torch._scaled_mm(
            qx,
            qw.transpose(-1, -2),
            scale_a=sx,
            scale_b=sw,
            out_dtype=out_dtype,
            use_fast_accum=not m_params.use_split_accumulator,
            bias=bias if use_bias_in_gemm else None,
        )
        y = scaled_mm_res[0] if isinstance(scaled_mm_res, tuple) else scaled_mm_res
        if bias is not None and not use_bias_in_gemm:
            # Check number of elements in bias tensor because it can be an empty tensor
            if bias.numel():
                y += bias
        if accumulate:
            assert out is not None, "Output tensor must be provided for accumulation."
            out.add_(y)
            y = out
        else:
            assert out is None, "Output tensor should be None when accumulate is False."
        return y
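As a numeric reference, the scaled GEMM call above computes (up to FP8 rounding and accumulation order) the same result as dequantizing both operands and running a plain matmul; a hedged emulation sketch, not the cuBLAS path:

    def qgemm_emulated(qx, sx, qw, sw, bias=None, out_dtype=torch.bfloat16):
        # Emulates y = (qx * sx) @ (qw * sw).T + bias in float32, which is what the
        # torch._scaled_mm call with scale_a=sx / scale_b=sw evaluates on the quantized data.
        x = qx.to(torch.float32) * sx
        w = qw.to(torch.float32) * sw
        y = x @ w.t()
        if bias is not None:
            y = y + bias.to(torch.float32)
        return y.to(out_dtype)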
    def transpose_qresult(self, qresult: CurrentScalingTensorRef) -> CurrentScalingTensorRef:
        """Python implementation of transpose qresult."""
        qx = qresult.data
        scale = qresult.scale
        assert qresult.data_t is None
        assert qresult.scale_t is None
        assert qx is not None
        qx_t = qx.transpose(-2, -1).contiguous()
        scale_t = scale
        qresult.data_t = qx_t
        qresult.scale_t = scale_t
        return qresult
    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensorStorage,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensorStorage:
        """Update the quantized tensor with the given tensor in-place

        Parameters
        ----------
        src: torch.Tensor
            Source tensor to copy from
        dst: ExperimentalQuantizedTensor
            Destination ExperimentalQuantizedTensor to update
        noop_flag: torch.Tensor, optional
            float32 flag indicating whether to avoid performing update
        """
        # Handle noop flag
        if noop_flag is not None and noop_flag.item() != 0:
            return dst
        # Make sure input is in expected format
        if not src.is_contiguous():
            src = src.contiguous()
        # Store the original shape and reshape for processing
        original_shape = src.shape
        if src.ndim > 2:
            src = src.view(-1, src.shape[-1])
        qx, sx, qx_t, sx_t = self._quantize(src)
        # Update the destination with new data
        dst.data = qx
        dst.scale = sx
        dst.data_t = qx_t
        dst.scale_t = sx_t
        dst.dtype = src.dtype
        dst.quant_dtype = self.dtype
        dst.original_shape = original_shape
        return dst
    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,  # pylint: disable=unused-argument
    ) -> CurrentScalingTensorRef:
        assert len(shape) == 2, "shape is not 2d"
        # Canonicalize tensor attributes
        if device is None:
            device = torch.device("cuda")
        # Allocate quantized data
        qx = torch.empty(shape, dtype=self.dtype, device=device)
        sx = torch.empty(1, dtype=torch.float32, device=device)
        # Allocate quantized data transpose if needed
        qx_t = None
        sx_t = None
        if self.columnwise_usage:
            inner_dim = qx.size(-1)
            qx_t = torch.empty(
                inner_dim,
                qx.numel() // inner_dim,
                dtype=self.dtype,
                device=device,
            )
            sx_t = torch.empty(1, dtype=torch.float32, device=device)
        # Construct quantized tensor
        return CurrentScalingTensorRef(
            data=qx,
            scale=sx,
            data_t=qx_t,
            scale_t=sx_t,
            dtype=dtype,
            device=device,
            quant_dtype=self.dtype,
            _quantizer=self,
            original_shape=shape,
        )
transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -18,9 +18,9 @@ def nvfp4_ref_rht_2d_quantizer_factory(role):
     """
     Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights).

-    Usage with CustomRecipe and fp8_autocast:
+    Usage with CustomRecipe and autocast:
         custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory)
-        with fp8_autocast(fp8_recipe=custom_recipe):
+        with autocast(fp8_recipe=custom_recipe):
             output = model(input)
     """
     if role == "linear_input":
...
@@ -338,7 +338,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
 class NVFP4QuantizerRef(Quantizer):
     """
-    NVFP4 quantizer for middleware between Transformer Engine and Kitchen
+    Reference implementation of NVFP4 quantizer
     """

     def __init__(
         self,
...
transformer_engine/pytorch/custom_recipes/utils.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
transformer_engine/pytorch/distributed.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -30,7 +30,7 @@ except ImportError:
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
-from . import torch_version
+from .torch_version import torch_version
 from .utils import (
     is_non_tn_fp8_gemm_supported,
     safely_set_viewless_tensor_data,
...
@@ -48,7 +48,7 @@ from .tensor.storage.float8_tensor_storage import Float8TensorStorage
 from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
 from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
 from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
-from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
+from ..debug.pytorch.debug_quantization import DebugQuantizedTensor

 __all__ = ["checkpoint", "CudaRNGStatesTracker"]
...
@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
 )


+def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
+    """Returns whether the rng state is a graph safe version."""
+    return graph_safe_rng_available() and isinstance(state, torch.Generator)
+
+
 def _get_cuda_rng_state(
     device: Union[int, str, torch.device] = "cuda",
     clone: bool = False,
...
@@ -340,9 +345,16 @@ class _CheckpointFunction(torch.autograd.Function):
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
+            ctx.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
+                if ctx.fwd_cuda_rng_state_tracker
+                else False
+            )
+        else:
+            ctx.graph_safe_rng_state = False
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)

         if context_fn is not None:
             forward_ctx, recompute_ctx = context_fn()
...
@@ -406,13 +418,13 @@ class _CheckpointFunction(torch.autograd.Function):
         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
+        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

         # Set the states to what it used to be before the forward pass.
         torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
...
@@ -427,7 +439,7 @@ class _CheckpointFunction(torch.autograd.Function):
         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
+        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
...
@@ -470,12 +482,21 @@ class _CheckpointFrame:
     def cache_rng_states(self, forward=True):
         """Cache fwd/bwd RNG states in the frame to restore later."""
-        rng_states = (
-            torch.get_rng_state(),
-            _get_cuda_rng_state(graph_safe=False),
-        )
+        rng_states = (torch.get_rng_state(),)
         if self.get_rng_state_tracker is not None:
-            rng_states += (self.get_rng_state_tracker().get_states(),)
+            tracker_states = self.get_rng_state_tracker().get_states()
+            self.graph_safe_rng_state = (
+                is_graph_safe_rng_state(next(iter(tracker_states.values())))
+                if tracker_states
+                else False
+            )
+            rng_states += (
+                _get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
+                tracker_states,
+            )
+        else:
+            self.graph_safe_rng_state = False
+            rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)

         if forward:
             self.fwd_rng_states = rng_states
...
@@ -490,7 +511,7 @@ class _CheckpointFrame:
             rng_states = self.bwd_rng_states
         torch.set_rng_state(rng_states[0])
-        _set_cuda_rng_state(rng_states[1], graph_safe=False)
+        _set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
         if self.get_rng_state_tracker is not None:
             self.get_rng_state_tracker().set_states(rng_states[2])
...
@@ -642,18 +663,18 @@ def checkpoint(
     Parameters
     ----------
-    function: Callable
+    function : Callable
               pytorch module used to run the forward and backward passes using
               the specified :attr:`args` and :attr:`kwargs`.
-    distribute_saved_activations: bool, default = False
-              if set to `True` and `use_reentrant=True`, first tensor argument is distributed
-              across the specified tensor parallel group (`tp_group`) before saving it for the
-              backward pass. This has no effect when `use_reentrant=False`.
-    get_rng_state_tracker: `Callable`, default = None
-              python callable which returns an instance of :func:`CudaRNGStatesTracker`.
+    distribute_saved_activations : bool, default = False
+              if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
+              across the specified tensor parallel group (``tp_group``) before saving it for the
+              backward pass. This has no effect when ``use_reentrant=False``.
+    get_rng_state_tracker : Callable, default = None
+              python callable which returns an instance of :class:`CudaRNGStatesTracker`.
     tp_group : ProcessGroup, default = None
-              tensor parallel process group. Used only when `distribute_saved_activations=True`
-              and `use_reentrant=True`. If `None`, it falls back to the default group.
+              tensor parallel process group. Used only when ``distribute_saved_activations=True``
+              and ``use_reentrant=True``. If ``None``, it falls back to the default group.
     use_reentrant : bool, default = True
               perform checkpointing in reentrant mode.
     args : tuple
...
@@ -778,8 +799,8 @@ class CudaRNGStatesTracker:
     For model parallelism, multiple RNG states need to simultaneously exist in order
     to execute operations in or out of the model parallel region. This class keeps
     track of the various RNG states and provides utility methods to maintain them and
-    execute parts of the model under a given RNG setting. Using the `add` method, a
-    cuda rng state is initialized based on the input `seed` and is assigned to `name`.
+    execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
+    cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
     Later, by forking the rng state, we can perform operations and return to our starting
     cuda state.
     """
...
@@ -812,18 +833,24 @@ class CudaRNGStatesTracker:
         Set the rng states. For efficiency purposes, we do not
         check the size of seed for compatibility.

-        states: Dict[str, torch.Tensor]
+        Parameters
+        ----------
+        states : Dict[str, torch.Tensor]
             A mapping from string names to RNG states.
         """
         self.states_ = states
+        # Update global states.
+        set_all_rng_states(self.states_)

     def add(self, name: str, seed: int) -> None:
         """
         Adds a new RNG state.

-        name: str
+        Parameters
+        ----------
+        name : str
             string identifier for the RNG state.
-        seed: int
+        seed : int
             PyTorch seed for the RNG state.
         """
         # Check seed is not already used.
...
@@ -857,7 +884,9 @@ class CudaRNGStatesTracker:
         Fork the cuda rng state, perform operations, and exit with
         the original state.

-        name: str
+        Parameters
+        ----------
+        name : str
             string identifier for the RNG state.
         """
         # Check if we have added the state
...
@@ -901,6 +930,34 @@ def reduce_scatter_along_first_dim(
     return output, handle


+@dataclass
+class _AsyncHandle:
+    """Handle for asynchronous collectives."""
+
+    async_handle: torch.distributed.Work
+    post_process_function: Optional[Callable] = None
+    post_process_function_args: Optional[Tuple[Any, ...]] = None
+    post_process_function_kwargs: Optional[Dict[str, Any]] = None
+    _synchronized: bool = False
+
+    def wait(self) -> None:
+        """Synchronize the asynchronous communication.
+
+        Perform post-processing if needed.
+        """
+        if self._synchronized:
+            return
+        self.async_handle.wait()
+        if self.post_process_function is not None:
+            args = self.post_process_function_args
+            args = () if args is None else args
+            kwargs = self.post_process_function_kwargs
+            kwargs = {} if kwargs is None else kwargs
+            self.post_process_function(*args, **kwargs)
+        self._synchronized = True
+
+
 def _all_gather_fp8(
     inp: torch.Tensor,
     process_group: dist_group_type,
...
@@ -948,7 +1005,13 @@ def _all_gather_fp8(
     if isinstance(inp, Float8Tensor):
         dtype = inp.dtype
         device = inp.device
+        # Temporarily ensure rowwise usage for output tensor creation
+        # since we're gathering rowwise data, not the transpose
+        init_rowwise_usage = quantizer.rowwise_usage
+        init_columnwise_usage = quantizer.columnwise_usage
+        quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+        quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
         out._data = torch.empty(
...
@@ -985,77 +1048,7 @@ def _all_gather_fp8(
     return out, handle


-def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]:
-    """Get quantizer format."""
-    if isinstance(quantizer, DebugQuantizer):
-        quantizer = quantizer.parent_quantizer
-    if isinstance(quantizer, Float8BlockQuantizer):
-        return quantizer.all_gather_usage
-    return None
-
-
-def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
-    """Make quantizer compact"""
-    _quantizer = quantizer
-    if isinstance(quantizer, DebugQuantizer):
-        _quantizer = quantizer.parent_quantizer
-    if isinstance(_quantizer, Float8BlockQuantizer):
-        _quantizer.all_gather_usage = compact
-
-
-def _post_process_fp8_blockwise_gather(
-    out: Float8BlockwiseQTensorStorage,
-    quantizer: Float8BlockQuantizer,
-    handle: Optional[torch.distributed.Work] = None,
-) -> Float8BlockwiseQTensorStorage:
-    """Post-process FP8 blockwise gather."""
-    if handle is not None:
-        handle.wait()
-        handle = None
-    if out._is_gemm_ready_format():
-        return out
-    needs_columnwise_data_transpose = (
-        quantizer is not None
-        and quantizer.columnwise_usage
-        and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
-    )
-    need_rowwise_scale_transpose = (
-        quantizer is not None
-        and quantizer.rowwise_usage
-        and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
-    )
-    # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
-    # columnwise compact format means doing 128x1 quantization of it
-    # so quantized tensor is 256x1024, scale inv is 2x1024
-    # If we were doing GEMM_READY format, then it's equivalent to do 1x128 quantization
-    # on a transposed 1024x256 tensor, so scale inv is 1024x2, cublas requires 2x1024
-    # Therefore, it turns out we don't need to transpose the scale inv, only columnwise data
-    if needs_columnwise_data_transpose:
-        out._transpose_columnwise_data()
-    if need_rowwise_scale_transpose:
-        out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
-    out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
-    return out
-
-
-@dataclass
-class _FP8BlockwiseAllGatherAsyncHandle:
-    """Handle for asynchronous FP8 blockwise all-gather."""
-
-    tensor: Float8BlockwiseQTensorStorage
-    quantizer: Float8BlockQuantizer
-    async_handle: torch.distributed.Work
-    _synchronized: bool = False
-
-    def wait(self) -> None:
-        """Wait for the async operation to complete and post-process the tensor."""
-        if self._synchronized:
-            return
-        self.async_handle.wait()
-        _post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
-        self._synchronized = True
-
-
-def _all_gather_fp8_blockwise(
+def _start_all_gather_fp8_blockwise(
     inp: torch.Tensor,
     process_group: dist_group_type,
     *,
...
@@ -1094,44 +1087,25 @@ def _all_gather_fp8_blockwise(
     )
     world_size = get_distributed_world_size(process_group)

-    # Check that quantizer is valid
-    if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
-        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
-    if not (quantizer.block_scaling_dim == 1 and (quantizer.block_len == 128 or quantizer.block_len == 64)):
-        raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
-
     # Output tensor dims
     if out_shape is None:
         out_shape = list(inp.size())
         out_shape[0] *= world_size

-    # Doing BF16 gather for now as baseline because it's simpler
-    if (
-        not isinstance(inp, Float8BlockwiseQTensorStorage)
-        and quantizer is not None
-        and not quantizer.is_quantizable(inp)
-    ):
-        out = torch.empty(
-            out_shape,
-            dtype=dtype,
-            device=device,
-            memory_format=torch.contiguous_format,
-        )
-        torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
-        out = quantizer(out)
-        return out, None
-
-    # Implementation of fp8 gather needs to account for:
-    # * Getting columnwise data as a transpose of how it is stored for GEMMS.
-    # * Gathering non GEMM swizzled scales.
-    # Set to compact usage in case the quantizer is not correctly configured
-    orig_all_gather_usage = quantizer.all_gather_usage
-    quantizer.all_gather_usage = True
+    # Check that quantizer is valid
+    if quantizer is None:
+        raise ValueError("Quantizer is missing")
+    if not isinstance(quantizer, Float8BlockQuantizer):
+        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
+
+    # Fall back to high-precision all-gather if FP8 is not supported
+    if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
+        out = torch.empty(out_shape, dtype=dtype, device=device)
+        torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+        orig_all_gather_usage = quantizer.all_gather_usage
+        quantizer.all_gather_usage = False
+        out = quantizer(out)
+        quantizer.all_gather_usage = orig_all_gather_usage
+        return out, None
+
+    # Quantize input tensor if needed
+    # Cast input tensor to Float8BlockwiseQTensor with required data
     if not isinstance(inp, Float8BlockwiseQTensorStorage):
         inp = quantizer(inp)
     elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
...
@@ -1146,14 +1120,9 @@ def _all_gather_fp8_blockwise(
     # Construct Float8BlockwiseQTensor output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
-    quantizer.all_gather_usage = orig_all_gather_usage
-
-    # Begin to do network communication, need to make sure compact format
-    if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
-        raise RuntimeError(
-            "All-gather with FP8 block-wise quantized tensor requires compact data format, "
-            f"but found data_format={inp._data_format}"
-        )
+
+    # Temporary buffers for all-gathering transposed buffers
+    interleaved_rowwise_scale_inv = None
+    interleaved_columnwise_data = None

     # Coalesce NCCL collectives
     with torch.distributed._coalescing_manager(
...
@@ -1162,11 +1131,17 @@ def _all_gather_fp8_blockwise(
         async_ops=async_op,
     ) as coalescing_manager:
-        # Gather Float8BlockwiseQTensor data for row-wise usage
+        # Gather row-wise data
         if quantizer.rowwise_usage:
+            # Launch all-gathers
+            scale_inv_shape = list(inp._rowwise_scale_inv.size())
+            scale_inv_shape[0] *= world_size
+            interleaved_rowwise_scale_inv = torch.empty(
+                scale_inv_shape,
+                dtype=inp._rowwise_scale_inv.dtype,
+                device=device,
+            )
             torch.distributed.all_gather_into_tensor(
-                out._rowwise_scale_inv,
+                interleaved_rowwise_scale_inv,
                 inp._rowwise_scale_inv,
                 group=process_group,
             )
...
@@ -1176,36 +1151,73 @@ def _all_gather_fp8_blockwise(
                 group=process_group,
             )
-        # Gather Float8BlockwiseQTensor data for column-wise usage
+        # Column-wise data
         if quantizer.columnwise_usage:
+            # Launch all-gathers
+            data_shape = list(inp._columnwise_data.size())
+            data_shape[0] *= world_size
+            interleaved_columnwise_data = torch.empty(
+                data_shape,
+                dtype=inp._columnwise_data.dtype,
+                device=device,
+            )
             torch.distributed.all_gather_into_tensor(
                 out._columnwise_scale_inv,
                 inp._columnwise_scale_inv,
                 group=process_group,
             )
             torch.distributed.all_gather_into_tensor(
-                out._columnwise_data,
+                interleaved_columnwise_data,
                 inp._columnwise_data,
                 group=process_group,
             )

-    handle = coalescing_manager if async_op else None
-
-    # Unlike MXFP8, this fp8 blockwise tensor primarily works with Hopper
-    # This means that we need to transpose the gathered columnwise data
-    # Example usage is grad_output tensor, ie. dY in linear backward
-    # We want to gather two FP8 tensors (rowwise and columnwise) along dim0
-    # and then transpose the columnwise data to match the rowwise data
-    # Make sure FP8 transpose is populated if needed
+    # Finalize communication if needed
+    async_handle = None
     if async_op:
-        handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
+        async_handle = _AsyncHandle(
+            coalescing_manager,
+            post_process_function=_finish_all_gather_fp8_blockwise,
+            post_process_function_args=(
+                out,
+                world_size,
+                interleaved_rowwise_scale_inv,
+                interleaved_columnwise_data,
+            ),
+        )
     else:
-        # if it's a sync op, we need to do the transpose here as post processing step
-        _post_process_fp8_blockwise_gather(out, quantizer, handle)
+        _finish_all_gather_fp8_blockwise(
+            out,
+            world_size,
+            interleaved_rowwise_scale_inv,
+            interleaved_columnwise_data,
+        )

-    return out, handle
+    return out, async_handle


+def _finish_all_gather_fp8_blockwise(
+    out: Float8BlockwiseQTensorStorage,
+    world_size: int,
+    interleaved_rowwise_scale_inv: Optional[torch.Tensor],
+    interleaved_columnwise_data: Optional[torch.Tensor],
+) -> Float8BlockwiseQTensorStorage:
+    """Post-process FP8 blockwise gather."""
+
+    # Fix interleaving in row-wise scales
+    if interleaved_rowwise_scale_inv is not None:
+        dim0 = out._rowwise_scale_inv.size(0)
+        view_in = interleaved_rowwise_scale_inv.view(world_size, dim0, -1)
+        view_out = out._rowwise_scale_inv.view(dim0, world_size, -1)
+        tex.swap_first_dims(view_in, out=view_out)
+
+    # Fix interleaving in column-wise data
+    if interleaved_columnwise_data is not None:
+        dim0 = out._columnwise_data.size(0)
+        view_in = interleaved_columnwise_data.view(world_size, dim0, -1)
+        view_out = out._columnwise_data.view(dim0, world_size, -1)
+        tex.swap_first_dims(view_in, out=view_out)
+
+    return out


 def _swap_first_dims(tensor: torch.Tensor, world_size: int):
...
@@ -1219,7 +1231,7 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int):
     """
     shape = tensor.shape
-    assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave."
+    assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave."
     first_dim = shape[0]
    flattened_trailing = math.prod(shape[1:])
     assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave."
...
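For intuition, the de-interleaving step above can be emulated in plain PyTorch. This sketch assumes `tex.swap_first_dims(view_in, out=view_out)` writes `view_in` with its two leading dimensions swapped into `view_out`, which is what the surrounding views and the `_swap_first_dims` helper suggest; the real kernel lives in transformer_engine_torch and is only called here.

    import torch

    def swap_first_dims_emulated(view_in: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # Assumed semantics of tex.swap_first_dims: out[j, i] = view_in[i, j].
        out.copy_(view_in.transpose(0, 1))
        return out

    # Each rank contributed a contiguous chunk along dim 0; swapping the leading dims of a
    # (world_size, local_rows, cols) view regroups the gathered rows per original row index.
    world_size, local_rows, cols = 2, 3, 4
    gathered = torch.arange(world_size * local_rows * cols).view(world_size, local_rows, cols)
    out = torch.empty(local_rows, world_size, cols, dtype=gathered.dtype)
    swap_first_dims_emulated(gathered, out)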
@@ -1650,7 +1662,7 @@ def gather_along_first_dim(
     if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
         quantizer, Float8BlockQuantizer
     ):
-        return _all_gather_fp8_blockwise(
+        return _start_all_gather_fp8_blockwise(
             inp,
             process_group,
             async_op=async_op,
...
@@ -1688,10 +1700,6 @@ def gather_along_first_dim(
         )
         if isinstance(inp, QuantizedTensorStorage):
             inp = inp.dequantize()
-        # Falling back to high-precision all-gather for Float8BlockQuantizer
-        # means that it should directly output GEMM_READY format
-        compact = _get_quantizer_format(quantizer)
-        _set_quantizer_format(quantizer, compact=False)
         out = torch.empty(
             out_shape,
             dtype=inp.dtype,
...
@@ -1700,7 +1708,6 @@ def gather_along_first_dim(
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
         out = quantizer(out)
-        _set_quantizer_format(quantizer, compact=compact)
         return out, None

     # Dequantize quantized tensor if not supported
...
@@ -2001,7 +2008,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     Parameters
     ----------
-    fsdp_root: torch.nn.Module
+    fsdp_root : torch.nn.Module
         FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
     """
     assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
...
transformer_engine/pytorch/export.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]:
     Parameters
     ----------
-    enabled: bool, default = `False`
+    enabled : bool, default = False
              whether or not to enable export
     """
...
transformer_engine/pytorch/float8_tensor.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
transformer_engine/pytorch/fp8.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
transformer_engine/pytorch/graph.py
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -7,6 +7,7 @@ from collections.abc import Iterable
 import contextlib
 import gc
 import warnings
+from math import ceil
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
...
@@ -61,6 +62,21 @@ def graph_pool_handle():
     return _graph_pool_handle()


+@contextlib.contextmanager
+def _none_grad_context_wrapper(inputs):
+    """
+    Wrapper to set the gradients of the inputs to None,
+    in case the backward pass makes grad accumulations.
+    """
+    original_input_grads = []
+    for input_tensor in inputs:
+        original_input_grads.append(input_tensor.grad)
+        input_tensor.grad = None
+    yield
+    for input_tensor, original_grad in zip(inputs, original_input_grads):
+        input_tensor.grad = original_grad
+
+
 @contextlib.contextmanager
 def _graph_context_wrapper(*args, **kwargs):
     """Wrapper around `torch.cuda.graph`.
...
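The `_none_grad_context_wrapper` added above exists because `torch.autograd.backward` accumulates into `.grad`, unlike `torch.autograd.grad`; a minimal sketch of the failure mode it guards against during graph warmup:

    import torch

    x = torch.ones(3, requires_grad=True)

    (x * 2).sum().backward()
    # x.grad is now tensor([2., 2., 2.])

    (x * 2).sum().backward()
    # x.grad is now tensor([4., 4., 4.]) -- the stale gradient accumulated in

    x.grad = None          # what the wrapper does for every captured input
    (x * 2).sum().backward()
    # x.grad is back to tensor([2., 2., 2.]), the gradient of this call alone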
@@ -127,6 +143,8 @@ def _make_graphed_callables(
     )

     # Check sizes of args
+    _order_without_wgrad = None
+    delay_wgrad_compute = False
     if _order is None:
         assert len(sample_args) == len(callables)
         assert len(sample_kwargs) == len(callables)
...
@@ -145,17 +163,34 @@ def _make_graphed_callables(
         # values indicate backward passes. Each
         # entry in sample_args corresponds to one of the forward
         # passes.
-        num_model_chunks = max(_order)
-        num_microbatches = len(_order) // num_model_chunks // 2
-        assert num_model_chunks * num_microbatches * 2 == len(_order)
+        _order_without_wgrad = []
+        for c_id in _order:
+            if ceil(c_id) != c_id:
+                delay_wgrad_compute = True
+                continue
+            _order_without_wgrad.append(c_id)
+        num_model_chunks = max(_order_without_wgrad)
+        num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2
+        assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad)
+        # When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which
+        # allows for fine-grained graph capture order.
+        if delay_wgrad_compute:
+            assert (
+                _num_layers_per_chunk is not None
+            ), "'_num_layers_per_chunk' must be provided when delay_wgrad_compute is True."
+            for num_layers in _num_layers_per_chunk:
+                assert (
+                    num_layers == 1
+                ), "Each model chunk must have only one layer when delay_wgrad_compute is True."
         # Determine number of layers in each model chunk.
         if _num_layers_per_chunk is None:
-            assert len(sample_args) * 2 >= len(_order) and (
-                len(sample_args) * 2 % len(_order) == 0
-            ), (
-                f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 % "
-                f"{len(_order)} == 0"
-            )
+            assert len(sample_args) * 2 >= len(_order_without_wgrad) and (
+                len(sample_args) * 2 % len(_order_without_wgrad) == 0
+            ), (
+                f"{len(sample_args)} * 2 >= {len(_order_without_wgrad)} and {len(sample_args)} * 2"
+                f" % {len(_order_without_wgrad)} == 0"
+            )
             num_layers = len(sample_args) // num_model_chunks // num_microbatches
             _num_layers_per_chunk = [num_layers] * num_model_chunks
...
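To spell out the `_order` convention that the `ceil`-based checks above rely on: positive integers are forward passes, negative integers are dgrad backward passes, and a value 0.5 below a dgrad entry marks its delayed wgrad pass. A small illustrative helper (not part of the diff) that mirrors those checks:

    from math import ceil

    def classify_order_value(c_id):
        # Mirrors the dispatch in _make_graphed_callables (illustrative only).
        if c_id > 0:
            return ("forward", int(c_id) - 1)        # 0-based model chunk index
        if ceil(c_id) == c_id:
            return ("dgrad", int(-c_id) - 1)         # backward dgrad for chunk -c_id
        assert ceil(c_id) - c_id == 0.5
        return ("wgrad", int(-ceil(c_id)) - 1)       # delayed wgrad for the same chunk

    # Example schedule with delayed wgrad: fwd 1, fwd 2, dgrad 2, wgrad 2, dgrad 1, wgrad 1
    for value in [1, 2, -2, -2.5, -1, -1.5]:
        print(value, classify_order_value(value))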
@@ -175,7 +210,7 @@ def _make_graphed_callables(
...
@@ -175,7 +210,7 @@ def _make_graphed_callables(
+
f
"entries when order input is provided but got
{
len
(
callables
)
}
."
+
f
"entries when order input is provided but got
{
len
(
callables
)
}
."
)
)
assert
len
(
sample_args
)
==
total_num_layers
*
num_microbatches
,
(
assert
len
(
sample_args
)
==
total_num_layers
*
num_microbatches
,
(
f
"Expected
{
total_num_layers
*
num_microbatches
}
"
f
"Expected
{
total_num_layers
*
num_microbatches
}
"
+
f
"args tuple, but got
{
len
(
sample_args
)
}
."
+
f
"args tuple, but got
{
len
(
sample_args
)
}
."
)
)
...
@@ -198,9 +233,10 @@ def _make_graphed_callables(
...
@@ -198,9 +233,10 @@ def _make_graphed_callables(
assert
(
assert
(
is_training
is_training
),
"`_reuse_graph_input_output_buffers` is only available in training mode."
),
"`_reuse_graph_input_output_buffers` is only available in training mode."
assert
isinstance
(
if
isinstance
(
sample_args
,
tuple
):
sample_args
,
list
sample_args
=
list
(
sample_args
)
),
"sample_args must be a list for _reuse_graph_input_output_buffers."
if
isinstance
(
sample_kwargs
,
tuple
):
sample_kwargs
=
list
(
sample_kwargs
)
# Reorganize args and kwargs for input tensor reuse.
# Reorganize args and kwargs for input tensor reuse.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
# fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples.
...
@@ -214,7 +250,7 @@ def _make_graphed_callables(
...
@@ -214,7 +250,7 @@ def _make_graphed_callables(
consumed_sample_q
=
{}
consumed_sample_q
=
{}
fwd_idx
=
[
0
]
*
num_model_chunks
fwd_idx
=
[
0
]
*
num_model_chunks
for
c_id
in
_order
:
for
c_id
in
_order
:
m_chunk
=
abs
(
c_id
)
-
1
m_chunk
=
abs
(
ceil
(
c_id
)
)
-
1
if
c_id
>
0
:
if
c_id
>
0
:
sample_start_idx
=
(
_prefix_num_layers
[
m_chunk
]
*
num_microbatches
)
+
(
sample_start_idx
=
(
_prefix_num_layers
[
m_chunk
]
*
num_microbatches
)
+
(
...
@@ -241,6 +277,8 @@ def _make_graphed_callables(
...
@@ -241,6 +277,8 @@ def _make_graphed_callables(
sample_args
[
per_callable_fwd_idx
]
=
sample_args
[
reuse_fwd_idx
]
sample_args
[
per_callable_fwd_idx
]
=
sample_args
[
reuse_fwd_idx
]
sample_kwargs
[
per_callable_fwd_idx
]
=
sample_kwargs
[
reuse_fwd_idx
]
sample_kwargs
[
per_callable_fwd_idx
]
=
sample_kwargs
[
reuse_fwd_idx
]
fwd_idx
[
m_chunk
]
+=
1
fwd_idx
[
m_chunk
]
+=
1
elif
ceil
(
c_id
)
!=
c_id
:
continue
else
:
else
:
num_consumed_samples
=
min
(
num_consumed_samples
=
min
(
len
(
fwd_sample_qs
[
m_chunk
]),
_num_layers_per_chunk
[
m_chunk
]
len
(
fwd_sample_qs
[
m_chunk
]),
_num_layers_per_chunk
[
m_chunk
]
...
@@ -411,13 +449,15 @@ def _make_graphed_callables(
...
@@ -411,13 +449,15 @@ def _make_graphed_callables(
for
hook
in
hooks
:
for
hook
in
hooks
:
hook
.
remove
()
hook
.
remove
()
if
is_training
:
if
is_training
:
grad_inputs
=
torch
.
autograd
.
grad
(
inputs
=
tuple
(
i
for
i
in
static_input_surface
if
i
.
requires_grad
)
outputs
=
tuple
(
o
for
o
in
outputs
if
o
.
requires_grad
),
with
_none_grad_context_wrapper
(
inputs
):
inputs
=
tuple
(
i
for
i
in
static_input_surface
if
i
.
requires_grad
),
torch
.
autograd
.
backward
(
grad_outputs
=
tuple
(
torch
.
empty_like
(
o
)
for
o
in
outputs
if
o
.
requires_grad
),
tuple
(
o
for
o
in
outputs
if
o
.
requires_grad
),
only_inputs
=
True
,
grad_tensors
=
tuple
(
allow_unused
=
allow_unused_input
,
torch
.
empty_like
(
o
)
for
o
in
outputs
if
o
.
requires_grad
),
)
)
grad_inputs
=
tuple
(
input
.
grad
for
input
in
inputs
)
# Filter module params that get None grad from grad_inputs and remove them
# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
# from static_input_surface. This is to ensure that the backward hooks
...
@@ -432,6 +472,14 @@ def _make_graphed_callables(
             module_params_with_grad = []
             for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
-                if (
-                    grad_inputs[grad_inputs_idx] is not None
-                    and grad_inputs_idx >= num_required_grad_sample_args
-                ):
+                if (
+                    grad_inputs[grad_inputs_idx] is None
+                    and grad_inputs_idx < num_required_grad_sample_args
+                ):
+                    assert allow_unused_input, (
+                        "The input tensor requires grad, but the grad is None after"
+                        " backward pass."
+                    )
+                elif (
+                    grad_inputs[grad_inputs_idx] is not None
+                    and grad_inputs_idx >= num_required_grad_sample_args
+                ):
...
@@ -477,9 +525,11 @@ def _make_graphed_callables(
     fwd_idx = [0] * num_model_chunks
     bwd_idx = [0] * num_model_chunks
     static_grad_outputs_dict = {}
+    wgrad_validation_list = [None] * len(_order)
     previous_chunk_last_callable_bwd_idx = None
-    for c_id in _order:
+    for i, c_id in enumerate(_order):
         if c_id > 0:
+            assert isinstance(c_id, int), "Forward order value must be an integer."
             # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
             m_chunk = c_id - 1
             for l_no in range(_num_layers_per_chunk[m_chunk]):
...
@@ -499,12 +549,65 @@ def _make_graphed_callables(
                 fwd_idx[m_chunk] += 1
         else:
             # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
-            m_chunk = -c_id - 1
+            m_chunk = -ceil(c_id) - 1
             previous_per_callable_bwd_idx = None
             for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
                 per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
                     bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                 )
+                if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
+                    # Check if bwd graph has corresponding wgrad graph:
+                    # Number of dgrad backward graphs should be equal to number of
+                    # wgrad backward graphs.
+                    # Note: For MCore, the validation rule is more strict (the next backward
+                    # of dgrad graph must be corresponding wgrad graph).
+                    if wgrad_validation_list[i] is None:
+                        same_bwd_c_id_list = [i]
+                        num_wgrad_c_id = 0
+                        for idx in range(i + 1, len(_order)):
+                            if _order[idx] > 0:
+                                continue
+                            if _order[idx] == c_id:
+                                same_bwd_c_id_list.append(idx)
+                            if _order[idx] + 0.5 == c_id:
+                                num_wgrad_c_id += 1
+                            if len(same_bwd_c_id_list) == num_wgrad_c_id:
+                                for same_c_id_idx in same_bwd_c_id_list:
+                                    wgrad_validation_list[same_c_id_idx] = True
+                                break
+                            if len(same_bwd_c_id_list) < num_wgrad_c_id:
+                                # It's impossible to have more wgrad than dgrad.
+                                wgrad_validation_list[i] = False
+                                break
+                        if wgrad_validation_list[i] is None:
+                            wgrad_validation_list[i] = False
+                    assert wgrad_validation_list[i], (
+                        f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
+                        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
+                    )
+                elif ceil(c_id) != c_id:
+                    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
+                    assert is_training, "Only training mode supports backward_dw."
+                    # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
+                    # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
+                    # the specific order of backward_dw.
+                    assert ceil(c_id) - c_id == 0.5, (
+                        "The order diff of wgrad and dgrad must be 0.5, "
+                        f"get {ceil(c_id) - c_id}."
+                    )
+                    assert need_bwd_dw_graph[per_callable_bwd_idx], (
+                        "No module needs wgrad computation but get float in order"
+                    )
+                    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
+                    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                        for module in visited_te_modules[per_callable_bwd_idx]:
+                            if (
+                                hasattr(module, "need_backward_dw")
+                                and module.need_backward_dw()
+                            ):
+                                module.backward_dw()
+                    continue
                 static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
                 static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
                 bwd_graph = bwd_graphs[per_callable_bwd_idx]
...
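The new validation above relies on a convention for the interleaved-pipelining _order list: positive integers are forward chunks, negative integers are dgrad backward chunks, and a value offset by -0.5 marks the standalone wgrad (backward_dw) capture for the same chunk. A small stand-alone sketch of that counting rule; count_dgrad_and_wgrad is a hypothetical helper written for illustration, not part of graph.py:

from math import ceil

def count_dgrad_and_wgrad(order, chunk_id):
    """Count dgrad (-chunk_id) and wgrad (-chunk_id - 0.5) entries for one model chunk."""
    dgrad = sum(1 for c in order if c == -chunk_id)
    wgrad = sum(1 for c in order if ceil(c) != c and ceil(c) == -chunk_id)
    return dgrad, wgrad

order = [1, 2, -1, -1.5, -2, -2.5]  # fwd chunk 1, fwd chunk 2, then dgrad+wgrad for both
assert count_dgrad_and_wgrad(order, 1) == (1, 1)
assert count_dgrad_and_wgrad(order, 2) == (1, 1)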
@@ -528,26 +631,17 @@ def _make_graphed_callables(
                 torch.empty_like(o) if o.requires_grad else None for o in static_outputs
             )
             if is_training:
-                with _graph_context_wrapper(bwd_graph, pool=mempool):
-                    grad_inputs = torch.autograd.grad(
-                        outputs=tuple(o for o in static_outputs if o.requires_grad),
-                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                        grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                        only_inputs=True,
-                        allow_unused=allow_unused_input,
-                        retain_graph=retain_graph_in_backward,
-                    )
-                # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
-                # So skip capturing it.
-                if need_bwd_dw_graph[per_callable_bwd_idx]:
-                    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
-                    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
-                        for module in visited_te_modules[per_callable_bwd_idx]:
-                            if (
-                                hasattr(module, "need_backward_dw")
-                                and module.need_backward_dw()
-                            ):
-                                module.backward_dw()
+                inputs = tuple(i for i in static_input_surface if i.requires_grad)
+                with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                    bwd_graph, pool=mempool
+                ):
+                    torch.autograd.backward(
+                        tuple(o for o in static_outputs if o.requires_grad),
+                        grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
+                        retain_graph=retain_graph_in_backward,
+                    )
+                grad_inputs = tuple(input.grad for input in inputs)
             # Constructs a tuple suitable for returning from Graphed.backward:
             # Pads out the actually-needed grads with Nones in gradient slots for inputs
             # that don't require grad. I couldn't think of a one-liner for this pattern.
...
@@ -596,7 +690,7 @@ def _make_graphed_callables(
                     per_callable_static_grad_inputs[idx]
                 )
                 previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
-            bwd_idx[m_chunk] += 1
+            if ceil(c_id) == c_id:
+                bwd_idx[m_chunk] += 1
     else:
         # Capture forward graphs
...
@@ -628,15 +722,17 @@ def _make_graphed_callables(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with _graph_context_wrapper(bwd_graph, pool=mempool):
-                grad_inputs = torch.autograd.grad(
-                    outputs=tuple(o for o in static_outputs if o.requires_grad),
-                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
-                    grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
-                    only_inputs=True,
-                    allow_unused=allow_unused_input,
-                    retain_graph=retain_graph_in_backward,
-                )
+            inputs = tuple(i for i in static_input_surface if i.requires_grad)
+            with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
+                bwd_graph, pool=mempool
+            ):
+                torch.autograd.backward(
+                    tuple(o for o in static_outputs if o.requires_grad),
+                    grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
+                    retain_graph=retain_graph_in_backward,
+                )
+            grad_inputs = tuple(input.grad for input in inputs)
         if need_bwd_dw_graph[bwd_idx]:
             with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
                 for module in visited_te_modules[bwd_idx]:
...
@@ -950,38 +1046,38 @@ def make_graphed_callables(
        Positional arguments to callable(s).
    num_warmup_iters: int, default = 3
        Number of warmup iterations.
-    allow_unused_input: bool, default = `False`
+    allow_unused_input: bool, default = False
        Whether to handle case where callable inputs
        and outputs are disconnected in compute graph.
    sample_kwargs: (tuple of) dict, optional
        Keyword arguments to callable(s)
-    pool: (tuple of) int, default = `None`, optional
+    pool: (tuple of) int, default = None, optional
        An instance returned from function `torch.cuda.graph_pool_handle` that hints
        this graph may share memory with the indicated pool.
-    retain_graph_in_backward: bool, default = `False`
+    retain_graph_in_backward: bool, default = False
        Whether to set retain_graph=True in backward graph capture.
-    _reuse_graph_input_output_buffers: bool, default = `False`
+    _reuse_graph_input_output_buffers: bool, default = False
        Reduce memory usage by reusing input/output data buffers between
        graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
        when `_order` is provided. All callables in `modules` are assumed to have
        inputs and outputs with the same dtype and shape.

-    Quantization related parameters
-    -------------------------------
+    Quantization parameters
+    ----------------------

-    enabled: (tuple of) bool, default = `False`
+    enabled: (tuple of) bool, default = False
        whether or not to enable low precision quantization (FP8/FP4).
        If tuple, the length must match the number of modules.
-    calibrating: bool, default = `False`
+    calibrating: bool, default = False
        calibration mode allows collecting statistics such as amax and scale
        data of quantized tensors even when executing without quantization enabled.
        This is useful for saving an inference ready checkpoint while training
        using a higher precision.
-    recipe: recipe.Recipe, default = `None`
+    recipe: recipe.Recipe, default = None
        recipe used for low precision quantization.
-    amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
+    amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = None
        distributed group over which amaxes for the quantized tensors
        are reduced at the end of each training step.
-    cache_quantized_params: bool, default = `False`
+    cache_quantized_params: bool, default = False
        Whether or not to cache quantized weights across microbatches. if set to `True`,
        the `is_first_microbatch` boolean argument must be passed into the forward
        method for TransformerEngine modules. When storing primary weights in low precision
...
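A minimal usage sketch for make_graphed_callables with the quantization parameters documented above. The module, shapes, and recipe choice are illustrative only, and the parameter names follow this version's docstring (older releases prefix them with fp8_):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)

graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    num_warmup_iters=3,
    enabled=True,                    # quantized execution inside the captured graph
    recipe=recipe.DelayedScaling(),  # quantization recipe, per the parameters above
)
out = graphed_layer(sample_input)
out.sum().backward()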
transformer_engine/pytorch/jit.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -8,7 +8,7 @@ from functools import wraps
 from typing import Callable, Optional, Tuple
 import torch
-from . import torch_version
+from .torch_version import torch_version
 from .export import is_in_onnx_export_mode
 from .utils import gpu_autocast_ctx
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
...

transformer_engine/pytorch/module/__init__.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...

transformer_engine/pytorch/module/_common.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
transformer_engine/pytorch/module/base.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -20,7 +20,6 @@ import torch.nn.functional as F
 from torch.distributed.tensor import DTensor
 import transformer_engine_torch as tex
-from transformer_engine.common.recipe import Recipe
 from ._common import _ParameterInitMeta, noop_cat
 from ..quantization import (
...
@@ -39,13 +38,19 @@ from ..distributed import (
     _fsdp_gather_tensors,
 )
 from ..constants import dist_group_type
+from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
 from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
 from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
 from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
-from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
+from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
+from ..utils import (
+    is_non_tn_fp8_gemm_supported,
+    torch_get_autocast_gpu_dtype,
+    get_nvtx_range_context,
+)
 from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ...common.recipe import DelayedScaling, Recipe
 from ...debug.pytorch.debug_state import TEDebugState
...
@@ -58,13 +63,9 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
 _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
-_multi_stream_cublas_workspace = []
 _dummy_wgrads = {}
 _multi_stream_cublas_batchgemm_workspace = []
-_cublas_workspace = None
 _ub_communicators = None
-ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
-_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
...
@@ -77,39 +78,6 @@ class UserBufferQuantizationMode(Enum):
     NONE = "none"
     FP8 = "fp8"
-def get_cublas_workspace_size_bytes() -> None:
-    """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
-    # Add env for control the padding for blaslt
-    if IS_HIP_EXTENSION:
-        return 134_217_728
-    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
-        # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
-        return 32 * 1024 * 1024 + 1024
-    return 4_194_304
-
-def get_workspace() -> torch.Tensor:
-    """Returns workspace for cublas."""
-    global _cublas_workspace
-    if _cublas_workspace is None:
-        _cublas_workspace = torch.empty(
-            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
-        )
-    return _cublas_workspace
-
-def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
-    """Returns workspace for multi-stream cublas."""
-    global _multi_stream_cublas_workspace
-    if not _multi_stream_cublas_workspace:
-        for _ in range(tex.get_num_cublas_streams()):
-            _multi_stream_cublas_workspace.append(
-                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
-            )
-    return _multi_stream_cublas_workspace
-
 def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
     """Returns workspace for multi-stream cublas."""
     global _multi_stream_cublas_batchgemm_workspace
...
@@ -126,7 +94,6 @@ if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
 else:
     remove_ag_gemm_dgrad = []
 def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
     """Returns a dummy tensor of given shape."""
     assert len(shape) == 2
...
@@ -154,27 +121,27 @@ def initialize_ub(
 ) -> None:
     r"""
     Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
-    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.
+    GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules.

     Parameters
     ----------
     shape : list
         shape of the communication buffer, typically set to be the same as the global shape of
-        the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
-        dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
+        the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch
+        dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)``
     tp_size : int
         number of GPUs in the tensor-parallel process group
     use_fp8 : bool = False
         allocate the communication buffer for FP8 GEMM inputs/outputs.
-        DEPRECATED: Please use `quantization_modes` instead.
+        DEPRECATED: Please use ``quantization_modes`` instead.
     quantization_modes : List[UserBufferQuantizationMode] = None
         if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
-        falls back to the legacy `use_fp8` parameter if `None` is provided.
+        falls back to the legacy ``use_fp8`` parameter if ``None`` is provided.
     dtype : torch.dtype = torch.bfloat16
-        non-FP8 data type of the communication buffer when `use_fp8 = False`
-    ub_cfgs: dict = None
+        non-FP8 data type of the communication buffer when ``use_fp8 = False``
+    ub_cfgs : dict = None
         Configuration dictionary with the structure
-        ```
+        ::
+
         {
             <gemm_name> : {
                 "method": <"ring_exchange" or "pipeline">,
...
@@ -189,20 +156,20 @@ def initialize_ub(
                 "fp8_buf": bool,
             }
         }
-        ```
-        for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
+
+        for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
         "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
-        "fc2_fprop", "fc2_wgrad"]`.
-        a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes`
+        "fc2_fprop", "fc2_wgrad"]``.
+        a list may be provided to specify different overlap configurations for different the quantization settings in ``quantization_modes``
     bootstrap_backend : str = None
-        `torch.distributed` communication backend for the all-gather, broadcast and
+        ``torch.distributed`` communication backend for the all-gather, broadcast and
         barrier collectives during Userbuffers initialization. Not all backends are
         valid for every cluster configuration and distributed launch method even if
         they are available in PyTorch. When left unset, the initialization prefers
         to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
-        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
+        not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this
         option and always initializes Userbuffers with direct MPI calls in C++,
-        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
+        which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time.
     """
     if not tex.device_supports_multicast():
         assert bool(int(os.getenv("UB_SKIPMC", "1"))), (
...
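An illustrative ub_cfgs dictionary following the structure documented above, using only the keys visible in this hunk ("method" and "fp8_buf"); other keys exist in the elided part of the structure. Buffer shape, tp_size, and the chosen overlap methods are arbitrary examples, and in a real run initialize_ub is called after torch.distributed has been initialized:

import torch
from transformer_engine.pytorch.module.base import initialize_ub

ub_cfgs = {
    "qkv_fprop": {"method": "ring_exchange"},
    "fc1_fprop": {"method": "pipeline"},
    "fc2_dgrad": {"method": "ring_exchange", "fp8_buf": True},
}

# Buffer shape collapses sequence and batch dims: (seq_len * batch, hidden).
initialize_ub(
    shape=[2048 * 2, 4096],
    tp_size=8,
    dtype=torch.bfloat16,
    ub_cfgs=ub_cfgs,
)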
@@ -299,16 +266,6 @@ def initialize_ub(
             flush=True,
         )
-    # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
-    global _cublas_workspace
-    if _cublas_workspace is None:
-        _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
-    elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
-        # This ensures we don't do `.repeat()` on an already expanded workspace
-        _cublas_workspace = torch.empty(
-            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
-        ).repeat(_NUM_MAX_UB_STREAMS)
     # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
     layers_all_gather_overlap = [
         "qkv_fprop",
...
@@ -642,6 +599,8 @@ def fill_userbuffers_buffer_for_all_gather(
                 "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                 f"but got MXFP8 tensor with shape={tuple(local_shape)}"
             )
+        if local_tensor._with_gemm_swizzled_scales:
+            raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales")
         local_scale_inv = (
             local_tensor._rowwise_scale_inv
             if with_rowwise_data
...
@@ -674,6 +633,7 @@ def fill_userbuffers_buffer_for_all_gather(
             columnwise_scale_inv=columnwise_scale_inv,
             fp8_dtype=local_tensor._fp8_dtype,
             quantizer=quantizer,
+            with_gemm_swizzled_scales=False,
         )
         return global_tensor, local_tensor
...
@@ -1033,7 +993,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         Parameters
         ----------
-        tp_group : ProcessGroup, default = `None`
+        tp_group : ProcessGroup, default = None
                   tensor parallel process group.
         """
         self.tp_group = tp_group
...
@@ -1123,8 +1083,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """
         self.allow_different_data_and_param_types = allow_different_data_and_param_types
         self.forwarded_at_least_once = True
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
+            delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
             FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
         else:
             assert inp.is_cuda, "TransformerEngine needs CUDA."
...
@@ -1136,25 +1098,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             self.init_fp8_metadata(num_gemms=num_gemms)
             self._check_weight_tensor_recipe_correspondence()

-            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
-                assert self.fp8_meta["recipe"].reduce_amax, (
-                    "Amax reduction across tensor parallel group is "
-                    "necessary when using sequence parallelism with FP8."
-                )
-
-            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
-                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)
-
-            # Activation recomputation is used and this is the first forward phase.
-            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
-                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
-
-        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
+            delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
+            if delayed_scaling_recipe:
+                if self.sequence_parallel:
+                    assert self.fp8_meta["recipe"].reduce_amax, (
+                        "Amax reduction across tensor parallel group is "
+                        "necessary when using sequence parallelism with FP8."
+                    )
+                if not FP8GlobalStateManager.fp8_graph_capturing():
+                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)
+                # Activation recomputation is used and this is the first forward phase.
+                if self.training and is_fp8_activation_recompute_enabled():
+                    FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
+
+        with get_nvtx_range_context(self.__class__.__name__ + " forward"):
             if not allow_non_contiguous and not inp.is_contiguous():
                 inp = inp.contiguous()
             yield inp
-        if self.fp8 and in_fp8_activation_recompute_phase():
+        if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)

     def set_nccl_overlap_warning_if_tp(self) -> None:
...
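The forward path above now wraps its body in get_nvtx_range_context(...) rather than torch.cuda.nvtx.range(...). A minimal stand-in with the same observable effect, assuming the helper behaves like a plain NVTX range context manager (the real implementation lives in ..utils and may add more logic, for example skipping ranges when profiling is disabled); this sketch requires a CUDA build of PyTorch:

import contextlib
import torch

@contextlib.contextmanager
def nvtx_range(msg: str):
    torch.cuda.nvtx.range_push(msg)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()

with nvtx_range("MyModule forward"):
    pass  # module body runs inside the profiler range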
@@ -1243,18 +1207,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
         if ctx.debug:
             grad_output_ = quantizer(grad_output)
-            if (
-                isinstance(
-                    grad_output_.get_tensor(True),
-                    (
-                        QuantizedTensor,
-                        Float8TensorStorage,
-                        MXFP8TensorStorage,
-                        Float8BlockwiseQTensorStorage,
-                    ),
-                )
-                and ctx.use_bias
-            ):
+            if ctx.use_bias:
                 grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
             else:
                 grad_bias = None
...
@@ -1434,7 +1387,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            workspace is being constructed or updated.
         cache_name: str, optional
            Key for caching.
-        update_workspace: bool, default = `True`
+        update_workspace: bool, default = True
            Update workspace with values from `tensor`.
         skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Take precedence
...
@@ -1478,6 +1431,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 reset_cache = True
             elif quantizer.columnwise_usage and out._columnwise_data is None:
                 reset_cache = True
+        elif isinstance(out, NVFP4TensorStorage):
+            if quantizer.rowwise_usage and out._rowwise_data is None:
+                reset_cache = True
+            elif quantizer.columnwise_usage and out._columnwise_data is None:
+                reset_cache = True
         if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
             reset_cache = True
         if reset_cache:
...
@@ -1576,7 +1534,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """
         if not self.need_backward_dw():
             return
-        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
+        with get_nvtx_range_context(f"_{self.__class__.__name__}_wgrad"):
             (wgrad, bgrad), _ = self.wgrad_store.pop()
             if not self.fuse_wgrad_accumulation:
                 weight_tensor = noop_cat(self._get_weight_tensors())
...
@@ -1618,6 +1576,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # we use the debug value from the first invocation in the iteration.
             debug = self.debug_enabled_in_this_iteration
         self.debug_last_iteration = TEDebugState.get_iteration()
+
+        if self.wgrad_store is not None:
+            if debug and self.wgrad_store.delay_wgrad_compute():
+                raise RuntimeError("Delayed wgrad compute is not supported in debug mode.")
+
         return debug

     def no_debug_features_active(self, quantizers):
...
@@ -1673,6 +1637,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """
         if not self.fp8 and not self.fp8_calibration:
             return
+        if not self.primary_weights_in_fp8:
+            return
         if not hasattr(self, "weight_names") or not self.weight_names:
             return
...
transformer_engine/pytorch/module/fp8_padding.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -10,7 +10,7 @@ import torch
 import transformer_engine_torch as tex
-from ..quantization import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
 from ..jit import no_torch_dynamo
...
@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        m_splits: List[int],
-        padded_m_splits: List[int],
-        is_grad_enabled: bool,
+        non_tensor_args: Tuple,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
+        # Reduce number of arguments to autograd function in order
+        # to reduce CPU overhead due to pytorch arg checking.
+        (m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
+
         # Make sure input dimensions are compatible
         in_features = inp.shape[-1]
...
@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function):
             grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
         )
-        return (grad_input, None, None, None)
+        return grad_input, None


 class Fp8Padding(torch.nn.Module):
...
@@ -111,14 +114,8 @@ class Fp8Padding(torch.nn.Module):
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

         if self.align_size is None:
-            self.align_size = (
-                32
-                if (
-                    FP8GlobalStateManager.get_fp8_recipe().mxfp8()
-                    or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
-                )
-                else 16
-            )
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            self.align_size = get_align_size_for_quantization(recipe)

         # FP8 padding calculate
         padded_m_splits = [
...
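A small arithmetic sketch of the padding computed above: each per-GEMM row count is rounded up to a multiple of align_size (16 for tensor-scaled FP8 and 32 for MXFP8/NVFP4 per the removed branch; the new get_align_size_for_quantization(recipe) helper picks the value from the active recipe). pad_splits below is a stand-alone illustration, not TE code:

def pad_splits(m_splits, align_size):
    # Round every per-GEMM row count up to the nearest multiple of align_size.
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]

assert pad_splits([100, 3, 64], 16) == [112, 16, 64]
assert pad_splits([100, 3, 64], 32) == [128, 32, 64]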
@@ -128,19 +125,20 @@ class Fp8Padding(torch.nn.Module):
         if m_splits == padded_m_splits:
             return inp, m_splits

-        if torch.is_grad_enabled():
+        is_grad_enabled = torch.is_grad_enabled()
+        if is_grad_enabled:
             fn = _Fp8Padding.apply
-            args = []
+            autograd_ctx = []
         else:
             fn = _Fp8Padding.forward
-            args = [None]
-        args += (
-            inp,
+            autograd_ctx = [None]
+        non_tensor_args = (
             m_splits,
             padded_m_splits,
-            torch.is_grad_enabled(),
+            is_grad_enabled,
         )
-        out = fn(*args)
+        out = fn(*autograd_ctx, inp, non_tensor_args)

         return out, padded_m_splits
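A stand-alone sketch of the calling convention adopted above: a custom autograd.Function takes one tensor plus a single tuple of non-tensor arguments instead of many separate arguments, which reduces PyTorch's per-call argument-processing overhead. The scale-and-shift op is a toy stand-in for the padding kernel:

import torch

class _ScaleAndShift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, non_tensor_args: tuple) -> torch.Tensor:
        (scale, shift, is_grad_enabled) = non_tensor_args
        if is_grad_enabled:
            ctx.scale = scale
        return inp * scale + shift

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient slot per forward input: the tensor gets a grad, the tuple gets None.
        return grad_output * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y = _ScaleAndShift.apply(x, (2.0, 1.0, torch.is_grad_enabled()))
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))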
transformer_engine/pytorch/module/fp8_unpadding.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 """FP8 Padding API"""

-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch

 import transformer_engine_torch as tex

-from ..quantization import FP8GlobalStateManager
+from ..quantization import FP8GlobalStateManager, get_align_size_for_quantization
 from ..jit import no_torch_dynamo
...
@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        m_splits: List[int],
-        padded_m_splits: List[int],
-        is_grad_enabled: bool,
+        non_tensor_args: Tuple,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
+        # Reduce number of arguments to autograd function in order
+        # to reduce CPU overhead due to pytorch arg checking.
+        (m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
+
         in_features = inp.shape[-1]

         # Allocate cast and transpose output tensor
...
@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function):
             grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits
         )
-        return (grad_input, None, None, None)
+        return grad_input, None


 class Fp8Unpadding(torch.nn.Module):
...
@@ -109,14 +112,8 @@ class Fp8Unpadding(torch.nn.Module):
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

         if self.align_size is None:
-            self.align_size = (
-                32
-                if (
-                    FP8GlobalStateManager.get_fp8_recipe().mxfp8()
-                    or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
-                )
-                else 16
-            )
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            self.align_size = get_align_size_for_quantization(recipe)

         # FP8 padding calculate
         padded_m_splits = [
...
@@ -126,19 +123,20 @@ class Fp8Unpadding(torch.nn.Module):
         if m_splits == padded_m_splits:
             return inp

-        if torch.is_grad_enabled():
+        is_grad_enabled = torch.is_grad_enabled()
+        if is_grad_enabled:
             fn = _Fp8Unpadding.apply
-            args = []
+            autograd_ctx = []
         else:
             fn = _Fp8Unpadding.forward
-            args = [None]
-        args += (
-            inp,
+            autograd_ctx = [None]
+        non_tensor_args = (
             m_splits,
             padded_m_splits,
-            torch.is_grad_enabled(),
+            is_grad_enabled,
         )
-        out = fn(*args)
+        out = fn(*autograd_ctx, inp, non_tensor_args)

         return out
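A hedged usage sketch tying the two modules together: pad variable-sized per-expert batches up to the quantization alignment before a GroupedLinear, then trim the rows back afterwards. Shapes and split sizes are illustrative only, and FP8 execution requires supporting hardware:

import torch
import transformer_engine.pytorch as te

num_gemms = 3
pad = te.Fp8Padding(num_gemms)
unpad = te.Fp8Unpadding(num_gemms)
linear = te.GroupedLinear(num_gemms, 256, 256).cuda()

x = torch.randn(100 + 3 + 64, 256, device="cuda")
m_splits = [100, 3, 64]

with te.fp8_autocast(enabled=True):
    x_pad, padded_m_splits = pad(x, m_splits)   # rows rounded up per expert
    y_pad = linear(x_pad, padded_m_splits)
    y = unpad(y_pad, m_splits)                  # back to the original row counts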
transformer_engine/pytorch/module/grouped_linear.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
+from itertools import chain
 import warnings
 import os
 import functools

 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 import transformer_engine_torch as tex

 from transformer_engine.common.recipe import Recipe
 from .base import (
     get_dummy_wgrad,
-    get_multi_stream_cublas_workspace,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
...
@@ -30,6 +30,7 @@ from ..utils import (
     clear_tensor_data,
     init_method_constant,
     requires_grad,
+    get_nvtx_range_context,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
...
@@ -42,7 +43,6 @@ from ..cpp_extensions import (
 )
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
-from ..graph import is_graph_capturing
 from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
...
@@ -52,7 +52,8 @@ from ..quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from ...debug.pytorch.debug_quantization import DebugQuantizer
+from ...debug.pytorch.debug_state import TEDebugState

 __all__ = ["GroupedLinear"]
...
@@ -62,32 +63,42 @@ class _GroupedLinear(torch.autograd.Function):
     Calls custom cuda extensions.
     """

-    # pylint: disable=keyword-arg-before-vararg
     @staticmethod
     def forward(
         ctx,
         inp: torch.Tensor,
-        m_splits: List[int],
-        use_bias: bool,
-        is_first_microbatch: Union[bool, None],
-        fp8: bool,
-        fp8_calibration: bool,
-        wgrad_store: WeightGradStore,
-        input_quantizers: List[Quantizer],
-        weight_quantizers: List[Quantizer],
-        output_quantizers: List[Quantizer],
-        grad_output_quantizers: List[Quantizer],
-        fuse_wgrad_accumulation: bool,
-        cpu_offloading: bool,
-        sequence_parallel: bool,
-        activation_dtype: torch.dtype,
-        is_grad_enabled: bool,
-        module,
-        skip_fp8_weight_update,
-        save_original_input,
+        non_tensor_args: Tuple,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
+        # Reduce number of arguments to autograd function in order
+        # to reduce CPU overhead due to pytorch arg checking.
+        (
+            m_splits,
+            use_bias,
+            is_first_microbatch,
+            fp8,
+            fp8_calibration,
+            wgrad_store,
+            input_quantizers,
+            weight_quantizers,
+            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
+            grad_output_quantizers,
+            fuse_wgrad_accumulation,
+            cpu_offloading,
+            sequence_parallel,
+            activation_dtype,
+            is_grad_enabled,
+            module,
+            skip_fp8_weight_update,
+            save_original_input,
+            debug,
+        ) = non_tensor_args
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
...
@@ -133,8 +144,17 @@ class _GroupedLinear(torch.autograd.Function):
         )
         inp_view = inp.reshape(-1, in_features)
         inputmats: list
-        if fp8:
-            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+        if fp8 and not debug:
+            # Disable bulk allocation when CPU offloading is active: offloading skips small
+            # tensors (like scales), but bulk allocation shares storage across all tensors,
+            # so if scales can't be offloaded, nothing in the group can be offloaded.
+            inputmats = tex.split_quantize(
+                inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+            )
+        elif debug:
+            inputmats = DebugQuantizer.multi_tensor_quantize(
+                inp_view, input_quantizers, m_splits, activation_dtype
+            )
         else:
             inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
...
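A tiny illustration of why bulk allocation and CPU offloading conflict, per the comment above: tensors carved out of one bulk buffer share a single storage, so they can only be moved off-device together. Pure PyTorch, independent of TE internals:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
bulk = torch.empty(1024, device=device)
a = bulk[:512]   # both views alias the same underlying storage
b = bulk[512:]
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
# Offloading `a` alone cannot free its memory while `b` still pins the shared storage.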
@@ -143,7 +163,7 @@ class _GroupedLinear(torch.autograd.Function):
         # Initialize weights
         weights_fp8: list
-        if fp8:
+        if fp8 or debug:
             # FP8 cast to workspace buffer
             weights_fp8 = []
             update_workspace = is_first_microbatch is None or is_first_microbatch
...
@@ -154,6 +174,7 @@ class _GroupedLinear(torch.autograd.Function):
                     cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                     update_workspace=update_workspace,
                     skip_update_flag=skip_fp8_weight_update,
+                    workspace_dtype=activation_dtype,
                 )
                 weights_fp8.append(weight_fp8)
...
@@ -165,7 +186,6 @@ class _GroupedLinear(torch.autograd.Function):
         if fp8 and activation_dtype == torch.float32:
             bias_dtype = torch.bfloat16  # FP8 GEMM only supports BF16/FP16 bias
         biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases

         # Initialize output tensor
         out = torch.empty(
             [sum(m_splits), weights_fp8[0].size(0)],
...
@@ -181,12 +201,12 @@ class _GroupedLinear(torch.autograd.Function):
             use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

         # Perform GEMM
-        _ = general_grouped_gemm(
+        general_grouped_gemm(
             weights_fp8,
             inputmats,
             [out],
+            output_quantizers,
             activation_dtype,
-            get_multi_stream_cublas_workspace(),
             single_output=True,
             m_splits=m_splits,
             bias=biases,
...
@@ -243,6 +263,10 @@ class _GroupedLinear(torch.autograd.Function):
             ctx.save_for_backward(*tensors_to_save)
             ctx.tensor_objects = tensor_objects

+            ctx.grad_input_quantizers = grad_input_quantizers
+            ctx.grad_output_quantizers = grad_output_quantizers
+            ctx.grad_weight_quantizers = grad_weight_quantizers
+
             ctx.weights_requires_grad = weights[0].requires_grad
             if fuse_wgrad_accumulation and ctx.weights_requires_grad:
                 # This check is needed to ensure that main_grad is not created
...
@@ -258,7 +282,7 @@ class _GroupedLinear(torch.autograd.Function):
             else:
                 ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
             ctx.device = device
-            ctx.grad_output_quantizers = grad_output_quantizers
+            ctx.output_quantizers = output_quantizers
             ctx.m_splits = m_splits
             ctx.num_gemms = num_gemms
             ctx.activation_dtype = activation_dtype
...
@@ -278,6 +302,7 @@ class _GroupedLinear(torch.autograd.Function):
                 or FP8GlobalStateManager.is_first_fp8_module()
             )
             ctx.wgrad_store = wgrad_store
+            ctx.debug = debug
             ctx.save_original_input = save_original_input
             ctx.input_quantizers = input_quantizers
...
@@ -287,7 +312,7 @@ class _GroupedLinear(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
         # pylint: disable=missing-function-docstring
-        with torch.cuda.nvtx.range("_GroupedLinear_backward"):
+        with get_nvtx_range_context("_GroupedLinear_backward"):
             saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
             N = ctx.num_gemms
             inputmats = saved_tensors[:N]
...
@@ -310,7 +335,7 @@ class _GroupedLinear(torch.autograd.Function):
             grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
             grad_output = [None] * ctx.num_gemms
             grad_biases = [None] * ctx.num_gemms
-            if ctx.fp8:
+            if ctx.fp8 and not ctx.debug:
                 if ctx.use_bias:
                     grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
                     recipe = ctx.fp8_recipe
...
@@ -337,6 +362,13 @@ class _GroupedLinear(torch.autograd.Function):
                         ctx.m_splits,
                         ctx.grad_output_quantizers,
                     )
+            elif ctx.debug:
+                grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
+                for i in range(ctx.num_gemms):
+                    grad_biases[i] = grad_output_mats[i].sum(dim=0)
+                grad_output = DebugQuantizer.multi_tensor_quantize(
+                    grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
+                )
             else:
                 # Only split grad output. Grad bias is fused with
                 # wgrad GEMM.
...
@@ -354,7 +386,7 @@ class _GroupedLinear(torch.autograd.Function):
             if ctx.requires_dgrad:
                 dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
-                if ctx.fp8:
+                if ctx.fp8 or ctx.debug:
                     recipe = ctx.fp8_recipe
                     if hasattr(recipe, "fp8_gemm_dgrad"):
                         dgrad_gemm_use_split_accumulator = (
...
@@ -374,8 +406,8 @@ class _GroupedLinear(torch.autograd.Function):
                     weights,
                     grad_output,
                     [dgrad],
+                    ctx.grad_input_quantizers,
                     ctx.activation_dtype,
-                    get_multi_stream_cublas_workspace(),
                     single_output=True,
                     layout="NN",
                     m_splits=ctx.m_splits,
...
@@ -412,17 +444,20 @@ class _GroupedLinear(torch.autograd.Function):
                 else:
                     input_quantizer.set_usage(rowwise=False, columnwise=True)
             inputmats: list
-            if ctx.fp8:
+            if ctx.fp8 and not ctx.debug:
                 inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
+            elif ctx.debug:
+                inputmats = DebugQuantizer.multi_tensor_quantize(
+                    inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
+                )
             else:
                 inputmats = torch.split(
                     cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits
                 )

             grouped_gemm_wgrad = functools.partial(
                 general_grouped_gemm,
+                quantization_params=ctx.grad_weight_quantizers,
                 out_dtype=ctx.activation_dtype,
-                workspaces=get_multi_stream_cublas_workspace(),
                 layout="NT",
                 grad=True,
                 m_splits=ctx.m_splits,
...
@@ -494,28 +529,11 @@ class _GroupedLinear(torch.autograd.Function):
             ):
                 grad_biases = [None] * ctx.num_gemms

-        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
+        if ctx.reduce_and_update_bwd_fp8_tensors:
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

         return (
             dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
+            None,
             *wgrad_list,
             *grad_biases,
         )
...
@@ -533,14 +551,14 @@ class GroupedLinear(TransformerEngineBaseModule):
...
@@ -533,14 +551,14 @@ class GroupedLinear(TransformerEngineBaseModule):
size of each input sample.
size of each input sample.
out_features : int
out_features : int
size of each output sample.
size of each output sample.
bias : bool, default =
`
True
`
bias : bool, default = True
if set to `False`, the layer will not learn an additive bias.
if set to
`
`False`
`
, the layer will not learn an additive bias.
init_method : Callable, default =
`
None
`
init_method : Callable, default = None
used for initializing weights in the following way: `init_method(weight)`.
used for initializing weights in the following way:
`
`init_method(weight)`
`
.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
When set to
`
`None`
`
, defaults to
`
`torch.nn.init.normal_(mean=0.0, std=0.023)`
`
.
get_rng_state_tracker : Callable, default =
`
None
`
get_rng_state_tracker : Callable, default = None
used to get the random number generator state tracker for initializing weights.
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default =
`
None
`
rng_tracker_name : str, default = None
the param passed to get_rng_state_tracker to get the specific rng tracker.
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
The device on which the parameters of the model will be allocated. It is the user's
...
@@ -549,33 +567,35 @@ class GroupedLinear(TransformerEngineBaseModule):
...
@@ -549,33 +567,35 @@ class GroupedLinear(TransformerEngineBaseModule):
Optimization parameters
Optimization parameters
-----------------------
-----------------------
fuse_wgrad_accumulation : bool, default =
'
False
'
fuse_wgrad_accumulation : bool, default = False
if set to `True`, enables fusing of creation and accumulation of
if set to
`
`True`
`
, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
have an additional
`
`main_grad`
`
attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
regular
`
`grad`
`
) which is a pre-allocated buffer of the correct
size to accumulate gradients in. This argument along with
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
will overwrite
`
`main_grad`
`
instead of accumulating.
return_bias : bool, default =
`
False
`
return_bias : bool, default = False
when set to `True`, this module will not apply the additive bias itself, but
when set to
`
`True`
`
, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default =
`
torch.get_default_dtype()
`
params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
would not fit in GPU memory.
delay_wgrad_compute : bool, default =
`
False
`
delay_wgrad_compute : bool, default = False
Whether to delay weight gradient computation
Whether to delay weight gradient computation
save_original_input : bool, default =
`
False
`
save_original_input : bool, default = False
If set to `True`, always saves the original input tensor rather than the
If set to
`
`True`
`
, always saves the original input tensor rather than the
cast tensor. In some scenarios, the input tensor is used by multiple modules,
cast tensor. In some scenarios, the input tensor is used by multiple modules,
and saving the original input tensor may reduce the memory usage.
and saving the original input tensor may reduce the memory usage.
Cannot work with FP8 DelayedScaling recipe.
Cannot work with FP8 DelayedScaling recipe.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
Notes
`parallel_mode` are used to determine the shapes of weights and biases.
-----
GroupedLinear doesn't really handle the TP communications inside. The ``tp_size`` and
``parallel_mode`` are used to determine the shapes of weights and biases.
The TP communication should be handled in the dispatch and combine stages of MoE models.
The TP communication should be handled in the dispatch and combine stages of MoE models.
"""
"""
...
@@ -601,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
         save_original_input: bool = False,
+        name: Optional[str] = None,
     ) -> None:
         super().__init__()
...
@@ -621,6 +642,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ), "GroupedLinear doesn't support Userbuffer overlap."
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.name = name
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
...
@@ -694,7 +716,8 @@ class GroupedLinear(TransformerEngineBaseModule):
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata(num_gemms=self.num_gemms)

-        self.reset_parameters(defer_init=device == "meta")
+        is_meta = torch.device(device).type == "meta"
+        self.reset_parameters(defer_init=is_meta)

         if self.wgrad_store.delay_wgrad_compute():
             for name, param in self.named_parameters():
...
@@ -706,13 +729,9 @@ class GroupedLinear(TransformerEngineBaseModule):
         """Init scales and amaxes for fwd | bwd."""
         super().set_meta_tensor(fwd, recipe)

-        # customize quantizers based on each recipe & layer configs
+        # Recipe-specific quantizer configuration
         recipe = FP8GlobalStateManager.get_fp8_recipe()
         if recipe.float8_current_scaling():
-            assert not self.tp_size > 1, (
-                "GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
-                "Because the TP communication is handled outside of this module."
-            )
             self._customize_quantizers_float8_current_scaling(fwd, recipe)

     def reset_parameters(self, defer_init=False):
...
@@ -770,58 +789,46 @@ class GroupedLinear(TransformerEngineBaseModule):
             first microbatch (since it is the first gradient being
             produced)
         """
+        debug = self.is_debug_iter()
         assert not isinstance(
             inp, QuantizedTensorStorage
         ), "GroupedLinear doesn't support input tensor in FP8."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

-        if FP8GlobalStateManager.fp8_graph_capturing():
-            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
-        else:
-            skip_fp8_weight_update = None
-        if skip_fp8_weight_update is not None:
-            is_first_microbatch = False
+        is_grad_enabled = torch.is_grad_enabled()

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
             weight_tensors = self._get_weight_tensors()
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

-            weight_quantizers = self._get_weight_quantizers()
-            input_quantizers, output_quantizers = (
-                [None] * self.num_gemms,
-                [None] * self.num_gemms,
-            )
-            grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
-            if self.fp8:
-                input_quantizers = [
-                    self.quantizers["scaling_fwd"][
-                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
-                    ]
-                    for i in range(self.num_gemms)
-                ]
-                # TODO: use internal after #1638 is merged. # pylint: disable=fixme
-                for i in range(self.num_gemms):
-                    input_quantizers[i].internal = False
-                if torch.is_grad_enabled():
-                    grad_output_quantizers = [
-                        self.quantizers["scaling_bwd"][
-                            self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
-                        ]
-                        for i in range(self.num_gemms)
-                    ]
-                    for i in range(self.num_gemms):
-                        grad_output_quantizers[i].internal = True
+            quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
+            if debug:
+                if self.no_debug_features_active(list(chain(*quantizers))):
+                    debug = False
+                    quantizers = self._get_quantizers()
+                if isinstance(weight_tensors, QuantizedTensorStorage):
+                    raise RuntimeError("FP8 weights are not supported in debug mode.")
+            (
+                input_quantizers,
+                weight_quantizers,
+                output_quantizers,
+                grad_input_quantizers,
+                grad_weight_quantizers,
+                grad_output_quantizers,
+            ) = quantizers

-            if torch.is_grad_enabled():
+            if is_grad_enabled:
                 linear_fn = _GroupedLinear.apply
-                args = []
+                autograd_ctx = []
             else:
                 linear_fn = _GroupedLinear.forward
-                args = [None]
-            args += (
-                inp,
+                autograd_ctx = [None]
+            non_tensor_args = (
                 m_splits,
                 self.apply_bias,
                 is_first_microbatch,
...
@@ -831,19 +838,20 @@ class GroupedLinear(TransformerEngineBaseModule):
                 input_quantizers,
                 weight_quantizers,
                 output_quantizers,
+                grad_input_quantizers,
+                grad_weight_quantizers,
                 grad_output_quantizers,
                 self.fuse_wgrad_accumulation,
                 is_cpu_offload_enabled(),
                 self.sequence_parallel,
                 self.activation_dtype,
-                torch.is_grad_enabled(),
+                is_grad_enabled,
                 self,
-                skip_fp8_weight_update,
+                None,  # skip_fp8_weight_update
                 self.save_original_input,
-                *weight_tensors,
-                *bias_tensors,
+                debug,
             )
-            out = linear_fn(*args)
+            out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

             if self.return_bias:
                 return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
...
@@ -856,7 +864,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         """
         if not self.need_backward_dw():
             return
-        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
+        with get_nvtx_range_context("_GroupedLinear_wgrad"):
             (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
             wgrad_list = tensor_list[2]
             weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
...
@@ -876,9 +884,12 @@ class GroupedLinear(TransformerEngineBaseModule):
     def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
         """Customize quantizers based on current scaling recipe + linear."""
-        assert (
-            recipe.float8_current_scaling()
-        ), "current scaling recipe quantizer customization here"
+        assert not self.tp_size > 1, (
+            "GroupedLinear doesn't support TP > 1 with Float8 current scaling. "
+            "Because the TP communication is handled outside of this module."
+        )
         if fwd:
             for i in range(self.num_gemms):
                 # set configs about amax epsilon and power_2_scale
...
@@ -932,3 +943,56 @@ class GroupedLinear(TransformerEngineBaseModule):
         for i in range(self.num_gemms):
             weight_quantizers[i].internal = True
         return weight_quantizers
+
+    def _get_quantizers(self):
+        weight_quantizers = self._get_weight_quantizers()
+        input_quantizers, output_quantizers = (
+            [None] * self.num_gemms,
+            [None] * self.num_gemms,
+        )
+        grad_input_quantizers, grad_weight_quantizers, grad_output_quantizers = (
+            [None] * self.num_gemms,
+            [None] * self.num_gemms,
+            [None] * self.num_gemms,
+        )
+        if self.fp8:
+            input_quantizers = [
+                self.quantizers["scaling_fwd"][
+                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
+                ]
+                for i in range(self.num_gemms)
+            ]
+            for i in range(self.num_gemms):
+                input_quantizers[i].internal = True
+                input_quantizers[i].optimize_for_gemm = True
+            if torch.is_grad_enabled():
+                grad_output_quantizers = [
+                    self.quantizers["scaling_bwd"][
+                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
+                    ]
+                    for i in range(self.num_gemms)
+                ]
+                for i in range(self.num_gemms):
+                    grad_output_quantizers[i].internal = True
+                    grad_output_quantizers[i].optimize_for_gemm = True
+        return (
+            input_quantizers,
+            weight_quantizers,
+            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
+            grad_output_quantizers,
+        )
+
+    def _get_debug_quantizers(self):
+        original_quantizers = self._get_quantizers()
+        assert TEDebugState.debug_enabled
+        names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"]
+        return tuple(
+            [
+                DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group)
+                for q_id, q in enumerate(qs)
+            ]
+            for name, qs in zip(names, original_quantizers)
+        )
transformer_engine/pytorch/module/layernorm.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -28,33 +28,30 @@ class LayerNorm(_LayerNormOp):
    Parameters
    ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
        Inner dimensions of input tensor
    eps : float, default = 1e-5
        A value added to the denominator of layer normalization for
        numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
        Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
        Tensor datatype
    zero_centered_gamma : bool, default = 'False'
-        If `True`, the :math:`\gamma` parameter is initialized to zero
+        If ``True``, the :math:`\gamma` parameter is initialized to zero
        and the calculation changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta

-    sm_margin: int or dict, default = 0
+    sm_margin : int or dict, default = 0
        Number of SMs to exclude when launching CUDA kernels. This
        helps overlap with other kernels, e.g. communication kernels.
        For more fine-grained control, provide a dict with the SM
-        margin at each compute stage ("forward", "backward",
-        "inference").
+        margin at each compute stage (``"forward"``, ``"backward"``,
+        ``"inference"``).

-    Legacy
-    ------
-    sequence_parallel: bool
-        Set a bool attr named `sequence_parallel` in the parameters.
+    sequence_parallel : bool
+        **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters.
        This is custom logic for Megatron-LM integration.
    """
...

transformer_engine/pytorch/module/layernorm_linear.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -15,11 +15,10 @@ from torch.nn import init
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
-from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.torch_version import torch_version
 from transformer_engine.pytorch.tensor.utils import is_custom
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
-    get_workspace,
     get_ub,
     TransformerEngineBaseModule,
     get_dummy_wgrad,
...
@@ -40,6 +39,7 @@ from ..utils import (
     nvtx_range_push,
     requires_grad,
     needs_quantized_gemm,
+    get_nvtx_range_context,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
...
@@ -64,7 +64,6 @@ from ..quantized_tensor import (
     restore_from_saved,
 )
 from ...debug.pytorch.debug_state import TEDebugState
-from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..cpu_offload import (
     is_cpu_offload_enabled,
...
@@ -107,47 +106,53 @@ class _LayerNormLinear(torch.autograd.Function):
         ln_bias: Union[torch.Tensor, None],
         weight: torch.Tensor,
         bias: torch.Tensor,
-        eps: float,
-        is_first_microbatch: Union[bool, None],
-        fp8: bool,
-        fp8_calibration: bool,
-        wgrad_store: WeightGradStore,
-        fuse_wgrad_accumulation: bool,
-        input_quantizer: Optional[Quantizer],
-        weight_quantizer: Optional[Quantizer],
-        output_quantizer: Optional[Quantizer],
-        grad_input_quantizer: Optional[Quantizer],
-        grad_weight_quantizer: Optional[Quantizer],
-        grad_output_quantizer: Optional[Quantizer],
-        cpu_offloading: bool,
-        tp_group: Union[dist_group_type, None],
-        tp_size: int,
-        sequence_parallel: bool,
-        tensor_parallel: bool,
-        activation_dtype: torch.dtype,
-        parallel_mode: Union[str, None],
-        return_layernorm_output: bool,
-        return_layernorm_output_gathered: bool,
-        is_grad_enabled: bool,
-        fwd_ln_sm_margin: int,
-        bwd_ln_sm_margin: int,
-        zero_centered_gamma: bool,
-        normalization: str,
-        ub_overlap_ag_fprop: bool,
-        ub_overlap_rs_fprop: bool,
-        ub_overlap_ag_dgrad: bool,
-        ub_overlap_rs_dgrad: bool,
-        ub_bulk_wgrad: bool,
-        ub_bulk_dgrad: bool,
-        ub_name: str,
-        fsdp_group: Union[dist_group_type, None],
-        module: torch.nn.Module,
-        skip_fp8_weight_update: bool,
-        symmetric_ar_type: str,
-        debug: Optional[bool] = False,
+        non_tensor_args: Tuple,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # pylint: disable=missing-function-docstring

+        # Reduce number of arguments to autograd function in order
+        # to reduce CPU overhead due to pytorch arg checking.
+        (
+            eps,
+            is_first_microbatch,
+            fp8,
+            fp8_calibration,
+            wgrad_store,
+            fuse_wgrad_accumulation,
+            input_quantizer,
+            weight_quantizer,
+            output_quantizer,
+            grad_input_quantizer,
+            grad_weight_quantizer,
+            grad_output_quantizer,
+            cpu_offloading,
+            tp_group,
+            tp_size,
+            sequence_parallel,
+            tensor_parallel,
+            activation_dtype,
+            parallel_mode,
+            return_layernorm_output,
+            return_layernorm_output_gathered,
+            is_grad_enabled,
+            fwd_ln_sm_margin,
+            bwd_ln_sm_margin,
+            zero_centered_gamma,
+            normalization,
+            ub_overlap_ag_fprop,
+            ub_overlap_rs_fprop,
+            ub_overlap_ag_dgrad,
+            ub_overlap_rs_dgrad,
+            ub_bulk_wgrad,
+            ub_bulk_dgrad,
+            ub_name,
+            fsdp_group,
+            module,
+            skip_fp8_weight_update,
+            symmetric_ar_type,
+            debug,
+        ) = non_tensor_args
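
The comment in the new code describes a general pattern: torch.autograd.Function.apply type-checks every positional argument, so collapsing the many non-tensor options into one tuple keeps that per-call overhead constant. A stripped-down sketch of the same idea, independent of Transformer Engine (all names here are illustrative):

import torch

class _PackedArgsFn(torch.autograd.Function):
    """Toy autograd function that takes one tensor plus a single tuple of options."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, non_tensor_args: tuple) -> torch.Tensor:
        scale, add_one = non_tensor_args  # unpack once inside the function
        ctx.scale = scale
        out = inp * scale
        return out + 1 if add_one else out

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient entry per forward argument: the options tuple gets None.
        return grad_out * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y = _PackedArgsFn.apply(x, (2.0, True))  # only two arguments cross the autograd boundary
y.sum().backward()
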
         # NVTX label for profiling
         nvtx_label = "transformer_engine._LayerNormLinear.forward"
         if ub_name is not None:
...
@@ -258,8 +263,6 @@ class _LayerNormLinear(torch.autograd.Function):
             if fp8 or debug:
                 ln_out = input_quantizer(ln_out)
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
-                if isinstance(input_quantizer, Float8BlockQuantizer):
-                    input_quantizer.all_gather_usage = False
                 ln_out_total = input_quantizer(ln_out_total)
             else:
                 quantizer = None
...
@@ -366,7 +369,6 @@ class _LayerNormLinear(torch.autograd.Function):
             gemm_out, *_, reduce_scatter_out = general_gemm(
                 weightmat,
                 ln_out_total,
-                get_workspace(),
                 quantization_params=output_quantizer,
                 out_dtype=activation_dtype,
                 bias=bias,
...
@@ -555,7 +557,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.ub_name is not None:
             nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
-        with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
+        with get_nvtx_range_context("_LayerNormLinear_backward"):
             saved_tensors = ctx.saved_tensors
             (  # pylint: disable=unbalanced-tuple-unpacking
                 inputmat,
...
@@ -743,7 +745,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 gemm_out, *_, reduce_scatter_out = general_gemm(
                     weight,
                     grad_output,
-                    get_workspace(),
                     layout="NN",
                     grad=True,
                     quantization_params=ctx.grad_input_quantizer,
...
@@ -870,7 +871,6 @@ class _LayerNormLinear(torch.autograd.Function):
             # Arguments to include in wgrad GEMM closure
             wgrad_gemm_kwargs = {
-                "workspace": get_workspace(),
                 "out_dtype": (
                     main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                 ),
...
@@ -1045,44 +1045,7 @@ class _LayerNormLinear(torch.autograd.Function):
             dbeta,
             wgrad,
             grad_bias,
-            None,  # eps
-            None,  # is_first_microbatch
-            None,  # fp8
-            None,  # fp8_calibration
-            None,  # wgrad_store
-            None,  # fuse_wgrad_accumulation
-            None,  # input_quantizer
-            None,  # weight_quantizer
-            None,  # output_quantizer
-            None,  # grad_input_quantizer
-            None,  # grad_weight_quantizer
-            None,  # grad_output_quantizer
-            None,  # cpu_offloading
-            None,  # tp_group
-            None,  # tp_size
-            None,  # sequence_parallel
-            None,  # tensor_parallel
-            None,  # activation_dtype
-            None,  # parallel_mode
-            None,  # return_layernorm_output
-            None,  # return_layernorm_output_gathered
-            None,  # is_grad_enabled
-            None,  # fwd_ln_sm_margin
-            None,  # bwd_ln_sm_margin
-            None,  # zero_centered_gamma
-            None,  # normalization
-            None,  # ub_overlap_ag_fprop
-            None,  # ub_overlap_rs_fprop
-            None,  # ub_overlap_ag_dgrad
-            None,  # ub_overlap_rs_dgrad
-            None,  # ub_bulk_dgrad
-            None,  # ub_bulk_wgrad
-            None,  # ub_name
-            None,  # fsdp_group
-            None,  # debug
-            None,  # module
-            None,  # skip_fp8_weight_update
-            None,  # symmetric_ar_type
+            None,
         )
...
@@ -1098,20 +1061,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
        size of each output sample.
    eps : float, default = 1e-5
        a value added to the denominator of layer normalization for numerical stability.
-    bias : bool, default = `True`
+    bias : bool, default = True
-        if set to `False`, the layer will not learn an additive bias.
+        if set to ``False``, the layer will not learn an additive bias.
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
        type of normalization applied.
-    init_method : Callable, default = `None`
+    init_method : Callable, default = None
-        used for initializing weights in the following way: `init_method(weight)`.
+        used for initializing weights in the following way: ``init_method(weight)``.
-        When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
+        When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
-    return_layernorm_output : bool, default = `False`
+    return_layernorm_output : bool, default = False
-        if set to `True`, output of layernorm is returned from the forward
+        if set to ``True``, output of layernorm is returned from the forward
        together with the output of the linear transformation.
        Example use case: residual connection for transformer module is
        taken post layernorm.
-    return_layernorm_output_gathered : bool, default = `False`
+    return_layernorm_output_gathered : bool, default = False
-        if set to `True`, output of layernorm is returned after the all
+        if set to ``True``, output of layernorm is returned after the all
        gather operation. Ignored if return_layernorm_output is False.
        Example use case: with sequence parallel, input to residual connection
        for transformer module (e.g. LoRA) will need to be gathered.
...
@@ -1122,10 +1085,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
        they are used to make the names of equally-sized parameters. If a dict
        (preferably an OrderedDict) is provided, the keys are used as names and
        values as split sizes along dim 0. The resulting parameters will have
-        names that end in `_weight` or `_bias`, so trailing underscores are
+        names that end in ``_weight`` or ``_bias``, so trailing underscores are
        stripped from any provided names.
    zero_centered_gamma : bool, default = 'False'
-        if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
+        if set to ``'True'``, gamma parameter in LayerNorm is initialized to 0 and
        the LayerNorm formula changes to

        .. math::
...
@@ -1135,53 +1098,53 @@ class LayerNormLinear(TransformerEngineBaseModule):
        The device on which the parameters of the model will be allocated. It is the user's
        responsibility to ensure all parameters are moved to the GPU before running the
        forward pass.
-    name: str, default = `None`
+    name : str, default = None
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
-    sequence_parallel : bool, default = `False`
+    sequence_parallel : bool, default = False
-        if set to `True`, uses sequence parallelism.
+        if set to ``True``, uses sequence parallelism.
-    tp_group : ProcessGroup, default = `None`
+    tp_group : ProcessGroup, default = None
        tensor parallel process group.
    tp_size : int, default = 1
        used as TP (tensor parallel) world size when TP groups are not formed during
        initialization. In this case, users must call the
-        `set_tensor_parallel_group(tp_group)` method on the initialized module before the
+        ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
        forward pass to supply the tensor parallel group needed for tensor and sequence
        parallel collectives.
-    parallel_mode : {None, 'column', 'row'}, default = `None`
+    parallel_mode : {None, 'column', 'row'}, default = None
        used to decide whether this Linear layer is Column Parallel Linear or Row
        Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
-        When set to `None`, no communication is performed.
+        When set to ``None``, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
-        if set to `True`, enables fusing of creation and accumulation of
+        if set to ``True``, enables fusing of creation and accumulation of
        the weight gradient. When enabled, it is assumed that the weights
-        have an additional `main_grad` attribute (used instead of the
+        have an additional ``main_grad`` attribute (used instead of the
-        regular `grad`) which is a pre-allocated buffer of the correct
+        regular ``grad``) which is a pre-allocated buffer of the correct
        size to accumulate gradients in. This argument along with
        weight tensor having attribute 'overwrite_main_grad' set to True
-        will overwrite `main_grad` instead of accumulating.
+        will overwrite ``main_grad`` instead of accumulating.
-    return_bias : bool, default = `False`
+    return_bias : bool, default = False
-        when set to `True`, this module will not apply the additive bias itself, but
+        when set to ``True``, this module will not apply the additive bias itself, but
        instead return the bias value during the forward pass together with the
        output of the linear transformation :math:`y = xA^T`. This is useful when
        the bias addition can be fused to subsequent operations.
-    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
+    params_dtype : torch.dtype, default = torch.get_default_dtype()
        it controls the type used to allocate the initial parameters. Useful when
        the model is trained with lower precision and the original FP32 parameters
        would not fit in GPU memory.
-    delay_wgrad_compute : bool, default = `False`
+    delay_wgrad_compute : bool, default = False
-        Whether or not to delay weight gradient computation. If set to `True`,
+        Whether or not to delay weight gradient computation. If set to ``True``,
-        it's the user's responsibility to call `module.backward_dw` to compute
+        it's the user's responsibility to call ``module.backward_dw`` to compute
        weight gradients.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
        Type of symmetric memory all-reduce to use during the forward pass.
        This can help in latency bound communication situations.
-        Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
+        Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
        is used.
    """
...
@@ -1462,15 +1425,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
         """Init scales and amaxes for fwd | bwd."""
         super().set_meta_tensor(fwd, recipe)

-        # customize quantizers based on each recipe & layer configs
+        # Recipe-specific quantizer configuration
         recipe = FP8GlobalStateManager.get_fp8_recipe()
         if recipe.float8_current_scaling():
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
-        elif recipe.float8_block_scaling():
-            self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
         elif recipe.nvfp4():
             self._customize_quantizers_nvfp4(fwd, recipe)
-        # elif other recipes (mxfp8, etc)

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
...
@@ -1542,8 +1502,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
             first microbatch (since it is the first gradient being
             produced)
         """
+        is_grad_enabled = torch.is_grad_enabled()
         if is_in_onnx_export_mode():
-            return self.onnx_forward(inp, fp8_output)
+            return self.onnx_forward(inp, fp8_output, is_grad_enabled)
         debug = self.is_debug_iter()
...
@@ -1565,9 +1527,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ).is_fp8_ubuf():
             fp8_grad = True

-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
+        with self.prepare_forward(
             inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
         ) as inp:
...
@@ -1575,14 +1535,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
             weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

             quantizers = (
-                self._get_quantizers(fp8_output, fp8_grad)
+                self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
                 if not debug
-                else self._get_debug_quantizers(fp8_output, fp8_grad)
+                else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
             )
             if debug:
                 if self.no_debug_features_active(quantizers):
                     debug = False
-                    quantizers = self._get_quantizers(fp8_output, fp8_grad)
+                    quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
             (
                 input_quantizer,
...
@@ -1593,18 +1553,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 grad_output_quantizer,
             ) = quantizers

-            if torch.is_grad_enabled():
+            if is_grad_enabled:
                 fwd_fn = _LayerNormLinear.apply
-                args = []
+                autograd_ctx = []
             else:
                 fwd_fn = _LayerNormLinear.forward
-                args = [None]
-            args += (
-                inp,
-                self.layer_norm_weight,
-                self.layer_norm_bias,
-                weight_tensor,
-                bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
+                autograd_ctx = [None]
+            non_tensor_args = (
                 self.eps,
                 is_first_microbatch,
                 self.fp8,
...
@@ -1626,8 +1581,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.parallel_mode,
                 self.return_layernorm_output,
                 self.return_layernorm_output_gathered,
-                torch.is_grad_enabled(),
-                self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
+                is_grad_enabled,
+                self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
                 self.bwd_ln_sm_margin,
                 self.zero_centered_gamma,
                 self.normalization,
...
@@ -1644,7 +1599,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.symmetric_ar_type,
                 debug,
             )
-            out = fwd_fn(*args)
+            out = fwd_fn(
+                *autograd_ctx,
+                inp,
+                self.layer_norm_weight,
+                self.layer_norm_bias,
+                weight_tensor,
+                bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
+                non_tensor_args,
+            )

         if self.return_layernorm_output:
             out, ln_out = out
...
@@ -1660,7 +1623,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             return out, ln_out
         return out

-    def _get_quantizers(self, fp8_output, fp8_grad):
+    def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
         if not self.fp8:
             return [None] * 6
         grad_input_quantizer = None
...
@@ -1669,12 +1632,16 @@ class LayerNormLinear(TransformerEngineBaseModule):
         output_quantizer = None
         input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
         input_quantizer.internal = True
+        if not (self.parallel_mode == "column" and self.sequence_parallel):
+            input_quantizer.optimize_for_gemm = True
         (weight_quantizer,) = self._get_weight_quantizers()
         if fp8_output:
             output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
-        if torch.is_grad_enabled():
+        if is_grad_enabled:
             grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
             grad_output_quantizer.internal = True
+            if not (self.parallel_mode == "row" and self.sequence_parallel):
+                grad_output_quantizer.optimize_for_gemm = True
             if fp8_grad:
                 grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
...
@@ -1687,8 +1654,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             grad_output_quantizer,
         )

-    def _get_debug_quantizers(self, fp8_output, fp8_grad):
-        original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
+    def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
+        original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
         assert TEDebugState.debug_enabled
         from ...debug.pytorch.debug_quantization import DebugQuantizer
...
@@ -1713,6 +1680,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self,
         inp: torch.Tensor,
         fp8_output: bool,
+        is_grad_enabled: bool,
     ) -> torch.Tensor:
         """
         ONNX-compatible version of the forward function that provides numerical equivalence
...
@@ -1728,7 +1696,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             weight_quantizer,
             output_quantizer,
             *_,
-        ) = self._get_quantizers(fp8_output, fp8_grad=False)
+        ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
         inp_dtype = inp.dtype

         weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
...
@@ -1857,14 +1825,3 @@ class LayerNormLinear(TransformerEngineBaseModule):
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
         return [weight_quantizer]
-
-    def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
-        """Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
-        assert (
-            recipe.float8_block_scaling()
-        ), "blockwise scaling recipe quantizer customization here"
-        if fwd:
-            if self.sequence_parallel and self.parallel_mode == "column":
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].all_gather_usage = True
transformer_engine/pytorch/module/layernorm_mlp.py
View file @ 0d874a4e
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
...
@@ -17,11 +17,10 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
-from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.torch_version import torch_version
 from transformer_engine.pytorch.tensor.utils import is_custom
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
-    get_workspace,
     _ub_communicators,
     get_ub,
     TransformerEngineBaseModule,
...
@@ -46,6 +45,7 @@ from ..utils import (
     clear_tensor_data,
     requires_grad,
     needs_quantized_gemm,
+    get_nvtx_range_context,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
...
@@ -57,6 +57,8 @@ from ..distributed import (
     use_reentrant_activation_recompute,
     in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
+    _get_cuda_rng_state,
+    _set_cuda_rng_state,
 )
 from ..constants import dist_group_type
 from ..jit import no_torch_dynamo
...
@@ -174,7 +176,7 @@ class _LayerNormMLP(torch.autograd.Function):
     """

     @staticmethod
-    def forward(
+    def _forward(
         ctx,
         inp: torch.Tensor,
         ln_weight: torch.Tensor,
...
@@ -183,55 +185,155 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_bias: torch.Tensor,
         fc2_weight: torch.Tensor,
         fc2_bias: torch.Tensor,
-        eps: float,
-        is_first_microbatch: Union[bool, None],
-        fp8: bool,
-        fp8_calibration: bool,
-        wgrad_store: WeightGradStore,
-        fuse_wgrad_accumulation: bool,
-        fc1_input_quantizer: Optional[Quantizer],
-        fc1_weight_quantizer: Optional[Quantizer],
-        fc1_output_quantizer: Optional[Quantizer],
-        fc1_grad_input_quantizer: Optional[Quantizer],
-        fc1_grad_weight_quantizer: Optional[Quantizer],
-        fc1_grad_output_quantizer: Optional[Quantizer],
-        fc2_input_quantizer: Optional[Quantizer],
-        fc2_weight_quantizer: Optional[Quantizer],
-        fc2_output_quantizer: Optional[Quantizer],
-        fc2_grad_input_quantizer: Optional[Quantizer],
-        fc2_grad_weight_quantizer: Optional[Quantizer],
-        fc2_grad_output_quantizer: Optional[Quantizer],
-        cpu_offloading: bool,
-        tp_group: Union[dist_group_type, None],
-        tp_size: int,
-        sequence_parallel: bool,
-        tensor_parallel: bool,
-        activation_dtype: torch.dtype,
-        return_layernorm_output: bool,
-        return_layernorm_output_gathered: bool,
-        bias_gelu_fusion: bool,
-        set_parallel_mode: bool,
-        is_grad_enabled: bool,
-        fwd_ln_sm_margin: int,
-        bwd_ln_sm_margin: int,
-        zero_centered_gamma: bool,
-        activation: str,
-        activation_params: Optional[dict],
-        normalization: str,
-        ub_overlap_ag: bool,
-        ub_overlap_rs: bool,
-        ub_overlap_rs_dgrad: bool,
-        ub_bulk_wgrad: bool,
-        ub_bulk_dgrad: bool,
-        gemm_gelu_fusion: bool,
-        fsdp_group: Union[dist_group_type, None],
-        module: torch.nn.Module,
-        skip_fp8_weight_update: bool,
-        symmetric_ar_type: str,
-        debug: Optional[bool] = False,
+        non_tensor_args: Tuple,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # pylint: disable=missing-function-docstring

+        # Reduce number of arguments to autograd function in order
+        # to reduce CPU overhead due to pytorch arg checking.
+        (
+            eps,
+            is_first_microbatch,
+            fp8,
+            fp8_calibration,
+            wgrad_store,
+            fuse_wgrad_accumulation,
+            fc1_input_quantizer,
+            fc1_weight_quantizer,
+            fc1_output_quantizer,
+            fc1_grad_input_quantizer,
+            fc1_grad_weight_quantizer,
+            fc1_grad_output_quantizer,
+            fc2_input_quantizer,
+            fc2_weight_quantizer,
+            fc2_output_quantizer,
+            fc2_grad_input_quantizer,
+            fc2_grad_weight_quantizer,
+            fc2_grad_output_quantizer,
+            cpu_offloading,
+            tp_group,
+            tp_size,
+            sequence_parallel,
+            tensor_parallel,
+            activation_dtype,
+            return_layernorm_output,
+            return_layernorm_output_gathered,
+            bias_gelu_fusion,
+            set_parallel_mode,
+            is_grad_enabled,
+            fwd_ln_sm_margin,
+            bwd_ln_sm_margin,
+            zero_centered_gamma,
+            activation,
+            activation_params,
+            normalization,
+            ub_overlap_ag,
+            ub_overlap_rs,
+            ub_overlap_rs_dgrad,
+            ub_bulk_wgrad,
+            ub_bulk_dgrad,
+            gemm_gelu_fusion,
+            fsdp_group,
+            module,
+            skip_fp8_weight_update,
+            symmetric_ar_type,
+            checkpoint,
+            debug,
+            recompute_for_bwd,
+        ) = non_tensor_args
+
+        # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take
+        if is_grad_enabled and not recompute_for_bwd:
+            ctx.checkpoint = checkpoint
+            if checkpoint:
+                # save the state of autocast and quantizers for recomputation
+                ctx.autocast_state = (
+                    FP8GlobalStateManager.get_autocast_state()
+                )  # to restore autocast state during recomputation
+                if (
+                    fp8
+                    and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__
+                    == "DelayedScaling"
+                ):
+                    # only applicable for delayed scaling
+                    FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(
+                        module.fp8_meta
+                    )  # to restore quantizers during recomputation
+                # save the rng states
+                ctx.cpu_rng_state = torch.get_rng_state()
+                ctx.cuda_rng_state = _get_cuda_rng_state()
+
+        # whether to save activations regularly, or save inputs for recomputation in bwd
+        save_for_checkpoint = checkpoint and is_grad_enabled and not recompute_for_bwd
+        # whether we are in the forward stage, or recomputing in the bwd stage (false if not checkpointing)
+        is_recomputation = checkpoint and is_grad_enabled and recompute_for_bwd
+
+        # save the initial state for recomputation by bwd
+        if save_for_checkpoint:
+            # save tensors
+            tensors_to_save, tensor_objects = prepare_for_saving(
+                inp,
+                ln_weight,
+                ln_bias,
+                fc1_weight,
+                fc1_bias,
+                fc2_weight,
+                fc2_bias,
+            )
+            ctx.save_for_backward(*tensors_to_save)
+            ctx.tensor_objects = tensor_objects
+            ctx.other_args = {
+                "eps": eps,
+                "is_first_microbatch": is_first_microbatch,
+                "fp8": fp8,
+                "fp8_calibration": fp8_calibration,
+                "wgrad_store": wgrad_store,
+                "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
+                "fc1_input_quantizer": fc1_input_quantizer,
+                "fc1_weight_quantizer": fc1_weight_quantizer,
+                "fc1_output_quantizer": fc1_output_quantizer,
+                "fc1_grad_input_quantizer": fc1_grad_input_quantizer,
+                "fc1_grad_weight_quantizer": fc1_grad_weight_quantizer,
+                "fc1_grad_output_quantizer": fc1_grad_output_quantizer,
+                "fc2_input_quantizer": fc2_input_quantizer,
+                "fc2_weight_quantizer": fc2_weight_quantizer,
+                "fc2_output_quantizer": fc2_output_quantizer,
+                "fc2_grad_input_quantizer": fc2_grad_input_quantizer,
+                "fc2_grad_weight_quantizer": fc2_grad_weight_quantizer,
+                "fc2_grad_output_quantizer": fc2_grad_output_quantizer,
+                "cpu_offloading": cpu_offloading,
+                "tp_group": tp_group,
+                "tp_size": tp_size,
+                "sequence_parallel": sequence_parallel,
+                "tensor_parallel": tensor_parallel,
+                "activation_dtype": activation_dtype,
+                "return_layernorm_output": return_layernorm_output,
+                "return_layernorm_output_gathered": return_layernorm_output_gathered,
+                "bias_gelu_fusion": bias_gelu_fusion,
+                "set_parallel_mode": set_parallel_mode,
+                "is_grad_enabled": is_grad_enabled,
+                "fwd_ln_sm_margin": fwd_ln_sm_margin,
+                "bwd_ln_sm_margin": bwd_ln_sm_margin,
+                "zero_centered_gamma": zero_centered_gamma,
+                "activation": activation,
+                "activation_params": activation_params,
+                "normalization": normalization,
+                "ub_overlap_ag": ub_overlap_ag,
+                "ub_overlap_rs": ub_overlap_rs,
+                "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
+                "ub_bulk_wgrad": ub_bulk_wgrad,
+                "ub_bulk_dgrad": ub_bulk_dgrad,
+                "gemm_gelu_fusion": gemm_gelu_fusion,
+                "fsdp_group": fsdp_group,
+                "module": module,
+                "skip_fp8_weight_update": skip_fp8_weight_update,
+                "symmetric_ar_type": symmetric_ar_type,
+                "checkpoint": checkpoint,
+                "debug": debug,
+                "recompute_for_bwd": True,  # set this to true for recomputation phase
+            }
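
The checkpoint path added above follows the usual activation-recomputation recipe: in the forward stage only the inputs (plus RNG and autocast/quantizer state) are stashed, and the full forward is replayed inside backward. A stripped-down sketch of that control flow, independent of Transformer Engine internals (all names below are illustrative):

import torch

class _CheckpointedMLP(torch.autograd.Function):
    """Toy MLP that recomputes its forward during backward instead of saving activations."""

    @staticmethod
    def _forward(ctx, x, w1, w2, recompute_for_bwd):
        h = torch.relu(x @ w1)
        out = h @ w2
        if recompute_for_bwd:
            # Recomputation pass: now it is worthwhile to keep intermediates for the gradient.
            ctx.h = h
        return out

    @staticmethod
    def forward(ctx, x, w1, w2):
        ctx.save_for_backward(x, w1, w2)       # inputs only, no activations
        ctx.rng_state = torch.get_rng_state()  # replayed so dropout-like ops would match
        with torch.no_grad():
            return _CheckpointedMLP._forward(ctx, x, w1, w2, recompute_for_bwd=False)

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, w2 = ctx.saved_tensors
        torch.set_rng_state(ctx.rng_state)
        _CheckpointedMLP._forward(ctx, x, w1, w2, recompute_for_bwd=True)
        h = ctx.h
        grad_w2 = h.t() @ grad_out
        grad_h = (grad_out @ w2.t()) * (h > 0)  # ReLU backward through the recomputed h
        grad_w1 = x.t() @ grad_h
        grad_x = grad_h @ w1.t()
        return grad_x, grad_w1, grad_w2

x = torch.randn(4, 8, requires_grad=True)
w1 = torch.randn(8, 16, requires_grad=True)
w2 = torch.randn(16, 4, requires_grad=True)
_CheckpointedMLP.apply(x, w1, w2).sum().backward()
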
# Make sure input dimensions are compatible
# Make sure input dimensions are compatible
in_features
,
inp_shape
=
ln_weight
.
numel
(),
inp
.
shape
in_features
,
inp_shape
=
ln_weight
.
numel
(),
inp
.
shape
assert
inp_shape
[
-
1
]
==
in_features
,
"GEMM not possible"
assert
inp_shape
[
-
1
]
==
in_features
,
"GEMM not possible"
...
@@ -253,7 +355,14 @@ class _LayerNormMLP(torch.autograd.Function):
            start_offload(inputmat)

        tp_world_size = get_distributed_world_size(tp_group)
-       backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
+       # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either
+       # 1) no checkpointing
+       # or 2) doing the recomputation with checkpointing
+       backwards_needs_fc1_input = fc1_weight.requires_grad and (
+           (is_grad_enabled and not checkpoint) or is_recomputation
+       )
        device = inp.device

        # Configure Userbuffers communication (comm+GEMM overlap)
...
@@ -311,7 +420,9 @@ class _LayerNormMLP(torch.autograd.Function):
            zero_centered_gamma,
        )
        ln_out_return = None
-       if return_layernorm_output or return_layernorm_output_gathered:
+       # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing
+       if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation:
            ln_out_return = ln_out

        # Prepare GEMM input
...
@@ -319,7 +430,9 @@ class _LayerNormMLP(torch.autograd.Function):
        ln_out_total = None
        ub_obj_lnout = None
        if sequence_parallel:
-           if return_layernorm_output_gathered:
+           # do not return ln output if checkpointing and in recomputation, not necessary
+           if return_layernorm_output_gathered and not is_recomputation:
                # Perform all-gather in high precision if gathered
                # norm output will be returned
                ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
...
@@ -327,8 +440,6 @@ class _LayerNormMLP(torch.autograd.Function):
            if fp8 or debug:
                ln_out = fc1_input_quantizer(ln_out)
                fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
-               if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
-                   fc1_input_quantizer.all_gather_usage = False
                ln_out_total = fc1_input_quantizer(ln_out_total)
            else:
                quantizer = None
...
@@ -442,7 +553,6 @@ class _LayerNormMLP(torch.autograd.Function):
            fc1_outputs = general_gemm(
                fc1_weight_final,
                ln_out_total,
-               get_workspace(),
                quantization_params=(
                    fc2_input_quantizer
                    if gemm_gelu_fusion
...
@@ -463,7 +573,12 @@ class _LayerNormMLP(torch.autograd.Function):
            # ------------------------------------------------------
            # Deallocate FC1 GEMM input tensor if no longer needed
-           if not is_grad_enabled and (ln_out_total is not ln_out_return):
+           # first part of if statement means that we only clear ln_out_total if
+           # 1) checkpointing and not recomputing (in the forward stage, not bwd recompute stage)
+           # 2) not checkpointing and grad disabled
+           if ((checkpoint and not is_recomputation) or not is_grad_enabled) and (
+               ln_out_total is not ln_out_return
+           ):
                clear_tensor_data(ln_out_total)

            # ACTIVATION - sometimes activation is fused with the GEMM above.
...
@@ -501,12 +616,27 @@ class _LayerNormMLP(torch.autograd.Function):
            else:
                act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
-           if not is_grad_enabled:
-               clear_tensor_data(fc1_out)

            if not fp8 and fp8_calibration:
                if fc2_input_quantizer is not None:
                    fc2_input_quantizer.calibrate(act_out)
+           # we want to skip fc2 computation if we are checkpointing and recomputing,
+           # otherwise we compute fc2
+           if not (is_recomputation and checkpoint):
+               # if we get to this point, we know this is not bwd recomputation
+               # so we must be in the fwd
+               # now is_grad_enabled can be true or false
+               # if false, can safely delete
+               # if true, we can only delete if checkpoint is true, since we will recompute anyways,
+               # otherwise, checkpoint is false, so cant delete
+               if checkpoint or not is_grad_enabled:
+                   # we can safely get rid of these if this is the case
+                   clear_tensor_data(fc1_out)
+               if not fp8 and fp8_calibration:
                    if fc2_weight_quantizer is not None:
                        fc2_weight_quantizer.calibrate(fc2_weight)
...
@@ -526,7 +656,6 @@ class _LayerNormMLP(torch.autograd.Function):
                gemm_out, *_, reduce_scatter_out = general_gemm(
                    fc2_weight_final,
                    act_out,
-                   get_workspace(),
                    out_dtype=activation_dtype,
                    bias=fc2_bias,
                    quantization_params=fc2_output_quantizer,
...
@@ -539,8 +668,8 @@ class _LayerNormMLP(torch.autograd.Function):
            # Finished FC2 GEMM...
            # ------------------------------------------------------

-           # Deallocate tensors if no longer needed
-           if not is_grad_enabled:
+           # Deallocate tensors if no longer needed, again, can safely deallocate
+           if checkpoint or not is_grad_enabled:
+               # same logic as last clear_tensor_data block
                clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)

            # Prepare output tensor
...
@@ -561,8 +690,24 @@ class _LayerNormMLP(torch.autograd.Function):
                fc2_out = gemm_out
            fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])

-           # Cache state for backward pass
-           if is_grad_enabled:
+           # now saving stuff for bwd:
+           # if we are using checkpointing, this information will be saved in the bwd recomputation stage, so can skip it in fwd
+           # if we are not checkpointing, then we must save this if grad is enabled
+           if is_grad_enabled and not save_for_checkpoint:
+               ctx.fc1_weight_quantizer = fc1_weight_quantizer
+               ctx.fc2_weight_quantizer = fc2_weight_quantizer
+
+               if not fc1_weight.requires_grad:
+                   if not return_layernorm_output:
+                       clear_tensor_data(ln_out)
+                   ln_out = None
+               if not fc2_weight.requires_grad:
+                   clear_tensor_data(act_out)
+                   act_out = None
+
+               if not checkpoint:
+                   # regular path, no selective activation checkpointing
                    if cpu_offloading:
                        mark_activation_offload(
                            inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
...
@@ -572,26 +717,27 @@ class _LayerNormMLP(torch.autograd.Function):
                    # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already
                    # shards/unshards the base weights so we don't do it ourselves
                    ctx.fsdp_group = fsdp_group
-                   ctx.fsdp_shapes = _fsdp_scatter_tensors(
-                       fsdp_group,
-                       mu,
-                       rsigma,
-                       ln_out,
-                       fc1_out_without_bias if bias_gelu_fusion else fc1_out,
-                       act_out,
-                       fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None,
-                       fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
-                   )
-                   ctx.fc1_weight_quantizer = fc1_weight_quantizer
-                   ctx.fc2_weight_quantizer = fc2_weight_quantizer
-                   if not fc1_weight.requires_grad:
-                       if not return_layernorm_output:
-                           clear_tensor_data(ln_out)
-                       ln_out = None
-                   if not fc2_weight.requires_grad:
-                       clear_tensor_data(act_out)
-                       act_out = None
+                   ctx.fsdp_shapes = (
+                       # again, ony relevant if we have activations to save
+                       _fsdp_scatter_tensors(
+                           fsdp_group,
+                           mu,
+                           rsigma,
+                           ln_out,
+                           fc1_out_without_bias if bias_gelu_fusion else fc1_out,
+                           act_out,
+                           (fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None),
+                           (fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None),
+                       )
+                   )
                    if cpu_offloading:
                        mark_not_offload(
...
@@ -604,7 +750,6 @@ class _LayerNormMLP(torch.autograd.Function):
                    fc2_weight,
                    fc2_bias,
                )
                tensors_to_save, tensor_objects = prepare_for_saving(
                    inputmat,
                    ln_weight,
...
@@ -622,6 +767,9 @@ class _LayerNormMLP(torch.autograd.Function):
                    rsigma,
                )
+               ctx.save_for_backward(*tensors_to_save)
+               ctx.tensor_objects = tensor_objects
+
                if fuse_wgrad_accumulation:
                    # This check is needed to ensure that main_grad is not created
                    # during the forward pass when using MCore FSDP as it creates
...
@@ -638,9 +786,6 @@ class _LayerNormMLP(torch.autograd.Function):
                    ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad
                    ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad

-               ctx.save_for_backward(*tensors_to_save)
-               ctx.tensor_objects = tensor_objects
-
                ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
                ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
                ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
...
@@ -695,11 +840,30 @@ class _LayerNormMLP(torch.autograd.Function):
        ):
            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
-           if in_fp8_activation_recompute_phase():
+           if in_fp8_activation_recompute_phase() or is_recomputation:
                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module

        ctx.wgrad_store = wgrad_store

+       if is_recomputation:
+           # return the recomputed tensors
+           return (
+               ctx,
+               inputmat,
+               ln_weight,
+               ln_out,
+               fc1_weight_final,
+               fc1_weight,
+               fc1_bias,
+               fc1_out,
+               fc1_out_without_bias,
+               act_out,
+               fc2_weight_final,
+               fc2_weight,
+               fc2_bias,
+               mu,
+               rsigma,
+           )
+       # we only get to this point if we are not recomputing for bwd, since that would have returned in the block above
        if return_layernorm_output:
            if return_layernorm_output_gathered:
                shape = list(inp_shape)
...
@@ -708,14 +872,101 @@ class _LayerNormMLP(torch.autograd.Function):
                return fc2_out, ln_out_return.view(inp_shape)
            return fc2_out

+   @staticmethod
+   def forward(
+       ctx,
+       inp: torch.Tensor,
+       ln_weight: torch.Tensor,
+       ln_bias: torch.Tensor,
+       fc1_weight: torch.Tensor,
+       fc1_bias: torch.Tensor,
+       fc2_weight: torch.Tensor,
+       fc2_bias: torch.Tensor,
+       non_tensor_args: Tuple,
+   ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
+       # pylint: disable=missing-function-docstring
+       # add recompute_for_bwd
+       non_tensor_args += (False,)
+       return _LayerNormMLP._forward(
+           ctx,
+           inp,
+           ln_weight,
+           ln_bias,
+           fc1_weight,
+           fc1_bias,
+           fc2_weight,
+           fc2_bias,
+           non_tensor_args,
+       )
+
+   @staticmethod
+   def _recompute(ctx):
+       # pylint: disable=missing-function-docstring
+       saved_tensors = ctx.saved_tensors
+       tensors = restore_from_saved(ctx.tensor_objects, saved_tensors)
+       # Delete the references to tensor objects once they've been consumed
+       # by the `restore_from_saved` method to construct back the actual tensors.
+       ctx.tensor_objects = None
+       if ctx.checkpoint:
+           # do recomputation from the original args
+           # backward is not in autocast context, so we set the state here
+           # we also have to set the quantizer states to what they were before the forward pass
+           # (only relevant for DelayedScaling recipe)
+           final_autocast_state = FP8GlobalStateManager.get_autocast_state()  # get current autocast state
+           FP8GlobalStateManager.set_autocast_state(ctx.autocast_state)  # set old autocast state
+           if (
+               ctx.other_args["fp8"]
+               and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
+           ):
+               # only applicable for delayed scaling
+               FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(
+                   ctx.other_args["module"].fp8_meta
+               )  # set old quantizer state
+           # get current rng state
+           final_cpu_rng_state = torch.get_rng_state()
+           final_cuda_rng_state = _get_cuda_rng_state()
+           # set rng state for fwd
+           torch.set_rng_state(ctx.cpu_rng_state)
+           _set_cuda_rng_state(ctx.cuda_rng_state)
+           out = _LayerNormMLP._forward(  # recompute
+               ctx,
+               *tensors,
+               tuple(ctx.other_args.values()),
+           )
+           FP8GlobalStateManager.set_autocast_state(final_autocast_state)  # restore autocast state
+           if (
+               ctx.other_args["fp8"]
+               and FP8GlobalStateManager.get_fp8_recipe().__class__.__name__ == "DelayedScaling"
+           ):
+               FP8GlobalStateManager.restore_fp8_meta_tensors(
+                   ctx.other_args["module"].fp8_meta
+               )  # restore quantizers
+           # set rng state for fwd
+           torch.set_rng_state(final_cpu_rng_state)
+           _set_cuda_rng_state(final_cuda_rng_state)
+           return out
+       # load from saved (return ctx is just because the other branch does too)
+       return tuple([ctx] + tensors)
    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
-       with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
-           saved_tensors = ctx.saved_tensors
+       with get_nvtx_range_context("_LayerNormMLP_backward"):
            (  # pylint: disable=unbalanced-tuple-unpacking
+               ctx,
                inputmat,
                ln_weight,
                ln_out,
...
@@ -730,11 +981,7 @@ class _LayerNormMLP(torch.autograd.Function):
                fc2_bias,
                mu,
                rsigma,
-           ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
-           # Delete the references to tensor objects once they've been consumed
-           # by the `restore_from_saved` method to construct back the actual tensors.
-           ctx.tensor_objects = None
+           ) = _LayerNormMLP._recompute(ctx)

            # Since main_grad can be modified inplace, it should not be a part of saved_tensors
            fc1_weight_main_grad = (
...
@@ -883,7 +1130,6 @@ class _LayerNormMLP(torch.autograd.Function):
                gemm_output, *_ = general_gemm(
                    fc2_weight,
                    grad_output,
-                   get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=(
...
@@ -977,7 +1223,6 @@ class _LayerNormMLP(torch.autograd.Function):
            # Arguments to include in wgrad GEMM closure
            fc2_wgrad_gemm_kwargs = {
-               "workspace": get_workspace(),
                "out_dtype": (
                    origin_fc2_weight.main_grad.dtype
                    if ctx.fuse_wgrad_accumulation
...
@@ -1155,7 +1400,6 @@ class _LayerNormMLP(torch.autograd.Function):
                gemm_out, *_, reduce_scatter_out = general_gemm(
                    fc1_weight,
                    dact,
-                   get_workspace(),
                    out=gemm_out,
                    out_dtype=ctx.activation_dtype,
                    quantization_params=ctx.fc1_grad_input_quantizer,
...
@@ -1234,7 +1478,6 @@ class _LayerNormMLP(torch.autograd.Function):
            # Arguments to include in wgrad GEMM closure
            fc1_wgrad_gemm_kwargs = {
-               "workspace": get_workspace(),
                "out_dtype": (
                    origin_fc1_weight.main_grad.dtype
                    if ctx.fuse_wgrad_accumulation
...
@@ -1429,52 +1672,7 @@ class _LayerNormMLP(torch.autograd.Function):
            fc1_bias_grad if fc1_bias is not None else None,
            fc2_wgrad,  # pylint: disable=possibly-used-before-assignment
            fc2_bias_grad,
-           None,  # eps
-           None,  # is_first_microbatch
-           None,  # fp8
-           None,  # fp8_calibration
-           None,  # wgrad_store
-           None,  # fuse_wgrad_accumulation
-           None,  # fc1_input_quantizer,
-           None,  # fc1_weight_quantizer,
-           None,  # fc1_output_quantizer,
-           None,  # fc1_grad_input_quantizer,
-           None,  # fc1_grad_weight_quantizer,
-           None,  # fc1_grad_output_quantizer,
-           None,  # fc2_input_quantizer,
-           None,  # fc2_weight_quantizer,
-           None,  # fc2_output_quantizer,
-           None,  # fc2_grad_input_quantizer,
-           None,  # fc2_grad_weight_quantizer,
-           None,  # fc2_grad_output_quantizer,
-           None,  # cpu_offloading
-           None,  # tp_group
-           None,  # tp_size
-           None,  # sequence_parallel
-           None,  # tensor_parallel
-           None,  # activation_dtype
-           None,  # return_layernorm_output
-           None,  # return_layernorm_output_gathered
-           None,  # bias_gelu_fusion
-           None,  # set_parallel_mode
-           None,  # is_grad_enabled
-           None,  # fwd_ln_sm_margin
-           None,  # bwd_ln_sm_margin
-           None,  # zero_centered_gamma
-           None,  # activation
-           None,  # activation_params
-           None,  # normalization
-           None,  # ub_overlap_ag
-           None,  # ub_overlap_rs
-           None,  # ub_overlap_rs_dgrad
-           None,  # ub_bulk_dgrad
-           None,  # ub_bulk_wgrad
-           None,  # gemm_gelu_fusion
-           None,  # fsdp_group
-           None,  # module
-           None,  # skip_fp8_weight_update
-           None,  # symmetric_ar_type
-           None,  # debug
+           None,
        )
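The collapsed return tuple works because `torch.autograd.Function.backward` must return exactly one value per argument of `forward`; with every non-tensor option bundled into the single `non_tensor_args` tuple, the long run of `None` placeholders reduces to one entry. A minimal sketch of the same convention, with hypothetical names:

import torch


class ScaleOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, non_tensor_args):
        (scale,) = non_tensor_args  # options travel as one tuple argument
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # one gradient per forward argument: a tensor grad, then None for the tuple
        return grad_out * ctx.scale, None


x = torch.randn(4, requires_grad=True)
y = ScaleOp.apply(x, (2.0,))
y.sum().backward()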
...
@@ -1491,38 +1689,38 @@ class LayerNormMLP(TransformerEngineBaseModule):
                     intermediate size to which input samples are projected.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
-   bias : bool, default = `True`
-         if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
+   bias : bool, default = True
+         if set to ``False``, the FC1 and FC2 layers will not learn an additive bias.
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    activation : str, default = 'gelu'
                activation function used.
-               Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
-               'silu', 'swiglu', and 'clamped_swiglu'.
-   activation_params : dict, default = `None`
+               Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
+               ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
+   activation_params : dict, default = None
              Additional parameters for the activation function.
-             At the moment, only used for 'clamped_swiglu' activation which
-             supports 'limit' and 'alpha' parameters.
-   init_method : Callable, default = `None`
-               used for initializing FC1 weights in the following way: `init_method(weight)`.
-               When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
-   output_layer_init_method : Callable, default = `None`
+             At the moment, only used for ``'clamped_swiglu'`` activation which
+             supports ``'limit'`` and ``'alpha'`` parameters.
+   init_method : Callable, default = None
+               used for initializing FC1 weights in the following way: ``init_method(weight)``.
+               When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
+   output_layer_init_method : Callable, default = None
                            used for initializing FC2 weights in the following way:
-                           `output_layer_init_method(weight)`. When set to `None`, defaults to
-                           `torch.nn.init.normal_(mean=0.0, std=0.023)`.
-   return_layernorm_output : bool, default = `False`
-                           if set to `True`, output of layernorm is returned from the forward
+                           ``output_layer_init_method(weight)``. When set to ``None``, defaults to
+                           ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
+   return_layernorm_output : bool, default = False
+                           if set to ``True``, output of layernorm is returned from the :meth:`forward` method
                           together with the output of the linear transformation.
                           Example use case: residual connection for transformer module
                           is taken post layernorm.
-   return_layernorm_output_gathered : bool, default = `False`
-                           if set to `True`, output of layernorm is returned after the all
-                           gather operation. Ignored if return_layernorm_output is False.
+   return_layernorm_output_gathered : bool, default = False
+                           if set to ``True``, output of layernorm is returned after the all
+                           gather operation. Ignored if ``return_layernorm_output`` is False.
                           Example use case: with sequence parallel, input to residual connection
                           for transformer module (e.g. LoRA) will need to be gathered.
                           Returning layernorm output gathered will prevent a redundant gather.
-   zero_centered_gamma : bool, default = 'False'
-                        if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
+   zero_centered_gamma : bool, default = False
+                        if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
                        the LayerNorm formula changes to
                        .. math::
...
@@ -1532,61 +1730,65 @@ class LayerNormMLP(TransformerEngineBaseModule):
            The device on which the parameters of the model will be allocated. It is the user's
            responsibility to ensure all parameters are moved to the GPU before running the
            forward pass.
-   name: str, default = `None`
+   name : str, default = None
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
-   set_parallel_mode : bool, default = `False`
-                      if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
+   set_parallel_mode : bool, default = False
+                      if set to ``True``, FC1 is used as Column Parallel and FC2 is used as Row
                      Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
-   sequence_parallel : bool, default = `False`
-                      if set to `True`, uses sequence parallelism.
-   tp_group : ProcessGroup, default = `None`
+   sequence_parallel : bool, default = False
+                      if set to ``True``, uses sequence parallelism.
+   tp_group : ProcessGroup, default = None
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
-            `set_tensor_parallel_group(tp_group)` method on the initialized module before the
+            ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
            forward pass to supply the tensor parallel group needed for tensor and sequence
            parallel collectives.

    Optimization parameters
    -----------------------
-   fuse_wgrad_accumulation : bool, default = 'False'
-                            if set to `True`, enables fusing of creation and accumulation of
+   fuse_wgrad_accumulation : bool, default = False
+                            if set to ``True``, enables fusing of creation and accumulation of
                            the weight gradient. When enabled, it is assumed that the weights
-                            have an additional `main_grad` attribute (used instead of the
-                            regular `grad`) which is a pre-allocated buffer of the correct
+                            have an additional ``main_grad`` attribute (used instead of the
+                            regular ``grad``) which is a pre-allocated buffer of the correct
                            size to accumulate gradients in. This argument along with
-                            weight tensor having attribute 'overwrite_main_grad' set to True
-                            will overwrite `main_grad` instead of accumulating.
-   return_bias : bool, default = `False`
-                when set to `True`, this module will not apply the additive bias for FC2, but
+                            weight tensor having attribute ``'overwrite_main_grad'`` set to True
+                            will overwrite ``main_grad`` instead of accumulating.
+   return_bias : bool, default = False
+                when set to ``True``, this module will not apply the additive bias for FC2, but
                instead return the bias value during the forward pass together with the
                output of the linear transformation :math:`y = xA^T`. This is useful when
                the bias addition can be fused to subsequent operations.
-   params_dtype : torch.dtype, default = `torch.get_default_dtype()`
+   params_dtype : torch.dtype, default = torch.get_default_dtype()
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
-   seq_length: int
+   seq_length : int
              sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
              functions are warmed up before training to ensure same kernels are used for forward
              propogation and activation recompute phase.
-   micro_batch_size: int
+   micro_batch_size : int
                    batch size per training step. Needed for JIT Warmup, a technique where jit
                    fused functions are warmed up before training to ensure same kernels are
                    used for forward propogation and activation recompute phase.
-   delay_wgrad_compute : bool, default = `False`
-                        Whether or not to delay weight gradient computation. If set to `True`,
-                        it's the user's responsibility to call `module.backward_dw` to compute
+   delay_wgrad_compute : bool, default = False
+                        Whether or not to delay weight gradient computation. If set to ``True``,
+                        it's the user's responsibility to call :meth:`backward_dw` to compute
                        weight gradients.
    symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
                       Type of symmetric memory all-reduce to use during the forward pass.
                       This can help in latency bound communication situations.
-                      Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
+                      Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce
                      is used.
+   checkpoint : bool, default = False
+               whether to use selective activation checkpointing, where activations are not saved for bwd,
+               and instead are recomputed (skipping fc2, as it is not needed for backward). Trades compute
+               for memory. default is false, in which activations are saved in fwd. not supported for onnx forward
    """
    def __init__(
...
@@ -1622,6 +1824,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        ub_bulk_wgrad: bool = False,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
+       checkpoint: bool = False,
    ) -> None:
        super().__init__()
...
@@ -1642,6 +1845,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        self.set_parallel_mode = set_parallel_mode
        self.zero_centered_gamma = zero_centered_gamma
        self.symmetric_ar_type = symmetric_ar_type
+       self.checkpoint = checkpoint
        # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
        self.gemm_gelu_fusion = (
...
@@ -1788,15 +1992,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
        """Init scales and amaxes for fwd | bwd."""
        super().set_meta_tensor(fwd, recipe)

-       # customize quantizers based on each recipe & layer configs
+       # Recipe-specific quantizer configuration
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if recipe.float8_current_scaling():
            self._customize_quantizers_float8_current_scaling(fwd, recipe)
-       elif recipe.float8_block_scaling():
-           self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
        elif recipe.nvfp4():
            self._customize_quantizers_nvfp4(fwd, recipe)
-       # elif for other recipes (mxfp8, etc.)

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
@@ -1857,8 +2058,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -1857,8 +2058,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
first microbatch (since it is the first gradient being
produced)
produced)
"""
"""
is_grad_enabled
=
torch
.
is_grad_enabled
()
if
is_in_onnx_export_mode
():
if
is_in_onnx_export_mode
():
return
self
.
onnx_forward
(
inp
)
return
self
.
onnx_forward
(
inp
,
is_grad_enabled
)
debug
=
self
.
is_debug_iter
()
debug
=
self
.
is_debug_iter
()
...
@@ -1874,19 +2077,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -1874,19 +2077,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
if
get_ub
(
"fc2_fprop"
,
FP8GlobalStateManager
.
is_fp8_enabled
()).
is_fp8_ubuf
():
if
get_ub
(
"fc2_fprop"
,
FP8GlobalStateManager
.
is_fp8_enabled
()).
is_fp8_ubuf
():
fp8_output
=
True
fp8_output
=
True
with
torch
.
cuda
.
device
(
with
self
.
prepare_forward
(
inp
,
num_gemms
=
2
)
as
inp
:
getattr
(
self
,
list
(
self
.
named_parameters
())[
0
][
0
]).
device
),
self
.
prepare_forward
(
inp
,
num_gemms
=
2
)
as
inp
:
quantizers
=
(
quantizers
=
(
self
.
_get_quantizers
(
fp8_output
)
self
.
_get_quantizers
(
fp8_output
,
is_grad_enabled
)
if
not
debug
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
)
else
self
.
_get_debug_quantizers
(
fp8_output
,
is_grad_enabled
)
)
)
if
debug
:
if
debug
:
if
self
.
no_debug_features_active
(
quantizers
):
if
self
.
no_debug_features_active
(
quantizers
):
debug
=
False
debug
=
False
quantizers
=
self
.
_get_quantizers
(
fp8_output
)
quantizers
=
self
.
_get_quantizers
(
fp8_output
,
is_grad_enabled
)
# Get quantizers
# Get quantizers
(
(
...
@@ -1919,20 +2120,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
                and self.bias_gelu_nvfusion
                and not use_reentrant_activation_recompute()
            ):
                self.bias_gelu_nvfusion = False

-           if torch.is_grad_enabled():
+           if is_grad_enabled:
                fwd_fn = _LayerNormMLP.apply
-               args = []
+               autograd_ctx = []
            else:
                fwd_fn = _LayerNormMLP.forward
-               args = [None]
-           args += (
-               inp,
-               self.layer_norm_weight,
-               self.layer_norm_bias,
-               fc1_weight,
-               fc1_bias,
-               fc2_weight,
-               fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
+               autograd_ctx = [None]
+           non_tensor_args = (
                self.eps,
                is_first_microbatch,
                self.fp8,
...
@@ -1961,8 +2156,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
                self.return_layernorm_output_gathered,
                self.bias_gelu_nvfusion and not self.fp8 and not debug,
                self.set_parallel_mode,
-               torch.is_grad_enabled(),
-               self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
+               is_grad_enabled,
+               self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.zero_centered_gamma,
                self.activation,
...
@@ -1978,9 +2173,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
                self,
                skip_fp8_weight_update,
                self.symmetric_ar_type,
+               self.checkpoint,
                debug,
            )
-           out = fwd_fn(*args)
+           out = fwd_fn(
+               *autograd_ctx,
+               inp,
+               self.layer_norm_weight,
+               self.layer_norm_bias,
+               fc1_weight,
+               fc1_bias,
+               fc2_weight,
+               fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
+               non_tensor_args,
+           )

        if self.return_layernorm_output:
            out, ln_out = out
...
@@ -1996,7 +2202,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
            return out, ln_out
        return out

-   def _get_quantizers(self, fp8_output):
+   def _get_quantizers(self, fp8_output, is_grad_enabled):
        (
            fc1_input_quantizer,
            fc1_output_quantizer,
...
@@ -2013,6 +2219,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
        if self.fp8 or self.fp8_calibration:
            fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
            fc1_input_quantizer.internal = True
+           if not self.sequence_parallel:
+               fc1_input_quantizer.optimize_for_gemm = True
            fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
            fc2_input_quantizer.set_usage(
                rowwise=True,
...
@@ -2021,20 +2229,24 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
                ),
            )
            fc1_input_quantizer.internal = True
            fc2_input_quantizer.internal = True
+           fc2_input_quantizer.optimize_for_gemm = True

        if fp8_output:
            fc2_output_quantizer = self.quantizers["scaling_fwd"][
                tex.FP8FwdTensors.GEMM2_OUTPUT
            ]

-       if torch.is_grad_enabled():
+       if is_grad_enabled:
            fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
                tex.FP8BwdTensors.GRAD_OUTPUT2
            ]
            fc2_grad_output_quantizer.internal = True
+           if not self.sequence_parallel:
+               fc2_grad_output_quantizer.optimize_for_gemm = True
            fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
                tex.FP8BwdTensors.GRAD_OUTPUT1
            ]
            fc1_grad_output_quantizer.internal = True
+           fc1_grad_output_quantizer.optimize_for_gemm = True

        return (
            fc1_input_quantizer,
...
@@ -2051,9 +2263,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
            fc2_grad_output_quantizer,
        )

-   def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+   def onnx_forward(
+       self, inp: torch.Tensor, is_grad_enabled: bool
+   ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
-       ONNX-compatible version of the forward function that provides numerical equivalence
+       ONNX-compatible version of the :meth:`forward` method that provides numerical equivalence
        while only using operations that have defined ONNX symbolic translations.
        This simplified implementation is designed specifically for inference scenarios.
        """
...
@@ -2061,14 +2275,23 @@ class LayerNormMLP(TransformerEngineBaseModule):
        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
        assert_warmed_up(self)

+       # Get quantizers
        (
            fc1_input_quantizer,
            fc1_weight_quantizer,
            _,
            _,
            fc2_input_quantizer,
            fc2_weight_quantizer,
-           output_quantizer,
-           *_,
-       ) = self._get_quantizers(False)
+           fc2_output_quantizer,
+           _,
+           _,
+       ) = self._get_quantizers(False, is_grad_enabled)

        inp_dtype = inp.dtype
        fc1_weight, fc2_weight = self._get_weight_tensors()
@@ -2142,7 +2365,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -2142,7 +2365,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_out
=
onnx_gemm
(
fc2_weight
,
act_out
,
fc2_bias
)
fc2_out
=
onnx_gemm
(
fc2_weight
,
act_out
,
fc2_bias
)
if
output_quantizer
is
not
None
:
if
fc2_
output_quantizer
is
not
None
:
raise
NotImplementedError
(
"ONNX export of quantized output is not supported"
)
raise
NotImplementedError
(
"ONNX export of quantized output is not supported"
)
if
self
.
return_layernorm_output
:
if
self
.
return_layernorm_output
:
...
@@ -2153,10 +2376,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -2153,10 +2376,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
return
fc2_out
,
fc2_bias
.
to
(
inp_dtype
)
return
fc2_out
,
fc2_bias
.
to
(
inp_dtype
)
return
fc2_out
return
fc2_out
def
_get_debug_quantizers
(
self
,
fp8_output
):
def
_get_debug_quantizers
(
self
,
fp8_output
,
is_grad_enabled
):
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
base_quantizers
=
list
(
self
.
_get_quantizers
(
fp8_output
))
base_quantizers
=
list
(
self
.
_get_quantizers
(
fp8_output
,
is_grad_enabled
))
assert
TEDebugState
.
debug_enabled
assert
TEDebugState
.
debug_enabled
def
make_debug
(
prefix
,
offset
):
def
make_debug
(
prefix
,
offset
):
...
@@ -2276,22 +2499,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
            fc2_weight_quantizer.internal = True
        return [fc1_weight_quantizer, fc2_weight_quantizer]

-   def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
-       """Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
-       assert (
-           recipe.float8_block_scaling()
-       ), "blockwise scaling recipe quantizer customization here"
-       if fwd:
-           if self.sequence_parallel and self.set_parallel_mode:
-               self.quantizers["scaling_fwd"][
-                   tex.FP8FwdTensors.GEMM1_INPUT
-               ].all_gather_usage = True
-       else:
-           if self.sequence_parallel and self.set_parallel_mode:
-               self.quantizers["scaling_bwd"][
-                   tex.FP8BwdTensors.GRAD_OUTPUT2
-               ].all_gather_usage = True
-
    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
...
@@ -2299,7 +2506,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        """
        if not self.need_backward_dw():
            return
-       with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
+       with get_nvtx_range_context("_LayerNormMLP_wgrad"):
            (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
            if self.use_bias and self.fc1_bias.grad is None:
                (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
...
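The delayed-wgrad path above is driven explicitly by the caller; a hedged usage sketch (hypothetical sizes, relying on the `delay_wgrad_compute` option documented earlier):

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, delay_wgrad_compute=True).cuda()
out = mlp(torch.randn(8, 1024, device="cuda", requires_grad=True))
out.sum().backward()  # dgrad runs now; the wgrad GEMMs are queued in the wgrad store
mlp.backward_dw()     # caller triggers the delayed weight-gradient computation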