OpenDAS / TransformerEngine

Commit 035c48c0, authored Mar 24, 2025 by yuguo

    Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: ea272d4a 86813893

Changes: 25. Showing 20 changed files with 903 additions and 123 deletions (+903 / -123).
3rdparty/cudnn-frontend  +1 -1
qa/L0_pytorch_unittest/test.sh  +2 -1
qa/L1_pytorch_distributed_unittest/test.sh  +1 -0
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py  +399 -0
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py  +35 -0
tests/pytorch/references/ref_per_tensor_cs.py  +20 -6
tests/pytorch/test_multi_tensor.py  +42 -0
tests/pytorch/test_sanity.py  +73 -1
transformer_engine/common/recipe/current_scaling.cu  +3 -36
transformer_engine/common/recipe/recipe_common.cuh  +56 -0
transformer_engine/pytorch/csrc/extensions.h  +6 -0
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu  +66 -0
transformer_engine/pytorch/csrc/extensions/normalization.cpp  +72 -17
transformer_engine/pytorch/csrc/extensions/pybind.cpp  +3 -0
transformer_engine/pytorch/csrc/extensions/recipe.cpp  +19 -2
transformer_engine/pytorch/fp8.py  +29 -1
transformer_engine/pytorch/module/_common.py  +3 -17
transformer_engine/pytorch/module/base.py  +43 -1
transformer_engine/pytorch/module/grouped_linear.py  +11 -14
transformer_engine/pytorch/module/layernorm_linear.py  +19 -26
3rdparty/cudnn-frontend @ 6ed19fd2 (compare 20c28ea7...6ed19fd2)
-Subproject commit 20c28ea798fe99e31d7274e009ee2fbf0e88abfd
+Subproject commit 6ed19fd213e33af2d9a1841b1023ccb2f81d45a1
qa/L0_pytorch_unittest/test.sh

@@ -26,7 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
@@ -39,6 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py ||
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
     exit 1
qa/L1_pytorch_distributed_unittest/test.sh

@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py |
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 # python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
+python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (new file, mode 100644)
#!/usr/bin/python3

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import datetime
import os
import sys

import torch
from torch import nn
import torch.distributed as dist

from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    Format,
    Recipe,
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor


def _get_raw_data(quantized_tensor):
    """Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
    if isinstance(quantized_tensor, Float8Tensor):
        assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
        assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
        return quantized_tensor._data
    else:
        raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")


class MiniZero_1:
    """A mini zero-1 optimizer implementation, just used for this test"""

    def __init__(self, weights, lr, dp_group):
        self.rank = dist.get_rank(dp_group)
        self.world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
        self.offsets = [0]
        for weight in self.weights:
            self.offsets.append(self.offsets[-1] + weight.numel())
        # Pad so the global buffer divides evenly across the world size; offsets[-1] may therefore
        # extend past the end of the last weight.
        if self.offsets[-1] % self.world_size != 0:
            self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size

        self.master_weights = []
        # The start offset of the master weight in the weight
        self.start_offsets = []
        # The overlapping area of the weight and this rank's local buffer
        self.overlapping_areas = []

        # The start and end of this rank's local buffer in the global buffer
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size

        for weight, offset in zip(self.weights, self.offsets[:-1]):
            if offset >= rank_end or (offset + weight.numel()) <= rank_start:
                # This weight is not in this rank's local buffer
                master_weight = None
                start_offset = None
                overlapping_area = None
            else:
                overlapping_start = max(rank_start, offset)
                overlapping_end = min(rank_end, offset + weight.numel())
                length = overlapping_end - overlapping_start
                start_offset = overlapping_start - offset
                if isinstance(weight, QuantizedTensor):
                    # If weight is a FP8 tensor, we need to use the original high precision version
                    # to initialize the master weight.
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight = high_precision_init_val.to(weight.device).float()[
                        start_offset : start_offset + length
                    ]
                else:
                    master_weight = (
                        weight.detach().view(-1).float()[start_offset : start_offset + length]
                    )
                overlapping_area = (overlapping_start, overlapping_end)
            self.master_weights.append(master_weight)
            self.start_offsets.append(start_offset)
            self.overlapping_areas.append(overlapping_area)

        # Create global buffer for grads reduce-scatter
        self.grad_buffer = torch.empty(
            [self.offsets[-1]], dtype=torch.float32, device=weights[0].device
        )
        self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]

        # Create global buffer for weights all-gather
        if isinstance(self.weights[0], QuantizedTensor):
            weight_buffer_dtype = torch.uint8
        else:
            weight_buffer_dtype = weights[0].dtype
        self.weight_buffer = torch.empty(
            [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
        )
        self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]

    def step(self):
        # -----------------------------------------------------------------------------------------
        # Step 1: Copy grads to the grad buffer
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))

        # -----------------------------------------------------------------------------------------
        # Step 2: Grads reduce-scatter
        # -----------------------------------------------------------------------------------------
        # Don't use reduce_scatter directly to explicitly control the reduce order.
        # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
        #                            group=self.dp_group)
        buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
        dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
        for i in range(1, self.world_size):
            buffers[0] += buffers[i]
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size
        self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
        self.grad_buffer_slice /= self.world_size

        # -----------------------------------------------------------------------------------------
        # Step 3: Update master weights
        # -----------------------------------------------------------------------------------------
        for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
            if master_weight is None:
                # This weight's master weight is owned by another rank.
                continue
            grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
            master_weight -= grad * self.lr

        # -----------------------------------------------------------------------------------------
        # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
        # -----------------------------------------------------------------------------------------
        if isinstance(self.weights[0], QuantizedTensor):
            # FP8 weights case
            for i in range(1, len(self.weights)):
                assert isinstance(self.weights[i], QuantizedTensor)
            cast_master_weights_to_fp8(
                self.weights, self.master_weights, self.start_offsets, self.dp_group
            )
        else:
            # BF16 weights case
            for weight, master_weight, start_offset in zip(
                self.weights, self.master_weights, self.start_offsets
            ):
                if master_weight is None:
                    continue
                start = start_offset
                end = start_offset + master_weight.numel()
                weight.data.view(-1)[start:end].copy_(master_weight)

        # -----------------------------------------------------------------------------------------
        # Step 5: Copy the updated weights (not all weights) to the weight buffer
        # -----------------------------------------------------------------------------------------
        for i in range(len(self.weights)):
            master_weight = self.master_weights[i]
            if master_weight is None:
                continue
            start_offset = self.start_offsets[i]
            if isinstance(self.weights[i], QuantizedTensor):
                weight = _get_raw_data(self.weights[i])
            else:
                weight = self.weights[i]
            weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
            overlapping_start, overlapping_end = self.overlapping_areas[i]
            self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)

        # -----------------------------------------------------------------------------------------
        # Step 6: Weight all-gather (FP8 or BF16)
        # -----------------------------------------------------------------------------------------
        dist.all_gather_into_tensor(
            self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
        )

        # -----------------------------------------------------------------------------------------
        # Step 7: Copy the gathered weights from weight buffer to the actual weights
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            if isinstance(weight, QuantizedTensor):
                weight = _get_raw_data(weight)
            weight.view(-1).data.copy_(self.weight_buffer[start:end])


class MiniOptimizer:

    def __init__(self, weights, lr, dp_group):
        self.world_size = dist.get_world_size(dp_group)
        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        master_weights = []
        for weight in self.weights:
            master_weights.append(weight.detach().float())
        self.master_weights = master_weights

    def step(self):
        for weight, master_weight in zip(self.weights, self.master_weights):
            main_grad = weight.main_grad
            # Don't use all-reduce directly to explicitly control the reduce order.
            # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
            buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
            dist.all_gather(buffers, main_grad, group=self.dp_group)
            for i in range(1, self.world_size):
                buffers[0] += buffers[i]
            main_grad.copy_(buffers[0])
            main_grad /= self.world_size
            master_weight -= main_grad * self.lr
            weight.data.copy_(master_weight)


def _test_zero_1(dp_group):
    """Make sure the implementation of zero-1 optimizer is correct"""
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    weights = [
        torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
        torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
        torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
    ]

    weights_1 = weights
    weights_2 = [weight.clone() for weight in weights]

    lr = 1.0
    optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
    optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)

    for _ in range(100):
        for w1, w2 in zip(weights_1, weights_2):
            main_grads = [
                torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
            ]
            # Choose based on rank to make sure the grads of different ranks are different.
            main_grad = main_grads[rank]
            w1.main_grad = main_grad
            w2.main_grad = main_grad

        optimizer_1.step()
        optimizer_2.step()

        for w1, w2 in zip(weights_1, weights_2):
            torch.testing.assert_close(w1, w2, atol=0, rtol=0)


def quantization_recipe(quantization) -> Recipe:
    """Quantization recipe setup"""
    if quantization == "fp8":
        return DelayedScaling(
            fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
        )
    elif quantization == "fp8_cs":
        return Float8CurrentScaling()
    else:
        raise ValueError(f"Unsupported quantization: {quantization}")


def _test_cast_master_weights_to_fp8(quantization, dp_group):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    torch.manual_seed(12345)
    torch.cuda.manual_seed(12345)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}

    # Create model with FP8 weights
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    # Allocate main_grads for each weight
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
        w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")

    optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)

    for _ in range(100):
        for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
            w_fp8.main_grad.zero_()
            w.main_grad.zero_()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)

        with te.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)


def main(argv=None, namespace=None):
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))

    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
    assert LOCAL_SIZE <= torch.cuda.device_count()
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
        "timeout": datetime.timedelta(seconds=30),
    }
    dist_init_kwargs["init_method"] = "env://"
    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    torch.cuda.set_device(LOCAL_RANK)
    dist.init_process_group(**dist_init_kwargs)

    parser = argparse.ArgumentParser()
    parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
    args = parser.parse_args(argv, namespace)

    dp_group = dist.new_group(backend="nccl")

    _test_zero_1(dp_group)
    _test_cast_master_weights_to_fp8(args.quantization, dp_group)

    dist.destroy_process_group()
    return 0


if __name__ == "__main__":
    sys.exit(main())
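
The MiniZero_1 optimizer above shards one padded flat buffer evenly across data-parallel ranks. The following standalone sketch is not part of the commit; the weight sizes and world size are illustrative values. It reproduces only the offset, padding, and per-rank overlap bookkeeping so the ownership logic can be checked by hand.

# Standalone sketch of the shard bookkeeping used by MiniZero_1 above (illustrative only).

def shard_bookkeeping(weight_numels, world_size):
    # offsets[i] .. offsets[i+1] is the range of weights[i] in the flat buffer
    offsets = [0]
    for numel in weight_numels:
        offsets.append(offsets[-1] + numel)
    # Pad so the flat buffer divides evenly across ranks
    if offsets[-1] % world_size != 0:
        offsets[-1] += world_size - offsets[-1] % world_size

    shard_size = offsets[-1] // world_size
    per_rank = []
    for rank in range(world_size):
        rank_start, rank_end = rank * shard_size, (rank + 1) * shard_size
        overlaps = []
        for numel, offset in zip(weight_numels, offsets[:-1]):
            if offset >= rank_end or offset + numel <= rank_start:
                overlaps.append(None)  # this rank owns no part of this weight
            else:
                overlaps.append((max(rank_start, offset), min(rank_end, offset + numel)))
        per_rank.append(overlaps)
    return offsets, per_rank


if __name__ == "__main__":
    # Three weights of 6, 10 and 3 elements sharded over 2 ranks: the flat buffer
    # is padded from 19 to 20 elements, so each rank owns a 10-element shard.
    offsets, per_rank = shard_bookkeeping([6, 10, 3], 2)
    print(offsets)   # [0, 6, 16, 20]
    print(per_rank)  # [[(0, 6), (6, 10), None], [None, (10, 16), (16, 19)]]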

tests/pytorch/distributed/test_cast_master_weights_to_fp8.py (new file, mode 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import subprocess
from pathlib import Path

import pytest
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

if torch.cuda.device_count() < 2:
    pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(2, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(quantization):
    test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py"
    test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization]
    result = subprocess.run(test_cmd, env=os.environ, check=False)
    assert result.returncode == 0


@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
def test_cast_master_weights_to_fp8(quantization):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    _run_test(quantization)
tests/pytorch/references/ref_per_tensor_cs.py

@@ -8,12 +8,8 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType_To_Torch
 
-# compute amax and scale
-def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
-    x_fp32 = x.to(torch.float32)
-    amax = torch.amax(torch.abs(x_fp32)).view(1)
-    assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
+# Compute scale and scale_inv from amax
+def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
     # Clamping amax to avoid division by small numbers
     amax = torch.max(amax, torch.tensor(eps))
@@ -52,6 +48,20 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
     # Compute scale_inv
     scale_inv = torch.reciprocal(scale)
+    return scale, scale_inv
+
+
+# compute amax and scale
+def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
+    x_fp32 = x.to(torch.float32)
+    amax = torch.amax(torch.abs(x_fp32)).view(1)
+    assert amax.dtype == torch.float, "amax must be a float tensor."
+    fp8_max = torch.finfo(quant_dtype).max
+    scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
+    # Clamping amax to avoid division by small numbers
+    amax = torch.max(amax, torch.tensor(eps))
     return scale, scale_inv, amax
@@ -103,3 +113,7 @@ def ref_per_tensor_cs_cast(
     qx_t = _multi_dim_transpose(qx)
     sx_t = sx
     return qx, sx, qx_t, sx_t
+
+
+def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
+    return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
tests/pytorch/test_multi_tensor.py

@@ -9,6 +9,9 @@ import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply
 
+from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
+
+
 input_size_pairs = [
     (7777 * 77, 555 * 555),
     (777, 555),
@@ -216,3 +219,42 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
     if per_tensor:
         torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
     assert overflow_buf.item() == 0
+
+
+@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
+@pytest.mark.parametrize("applier", appliers)
+@pytest.mark.parametrize("repeat", [1, 55])
+@pytest.mark.parametrize("max_fp8", [448.0, 57344.0])
+@pytest.mark.parametrize("pow_2_scales", [False, True])
+@pytest.mark.parametrize("epsilon", [0.0, 100.0])
+def test_multi_tensor_compute_scale_and_scale_inv(
+    input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon
+):
+    sizea, sizeb = input_size_pair
+    device = torch.device("cuda")
+
+    overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
+
+    a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
+    b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
+
+    amax_list = []
+    for i in range(repeat):
+        amax_list += [a.clone(), b.clone()]
+    scale_list = [torch.empty_like(x) for x in amax_list]
+    scale_inv_list = [torch.empty_like(x) for x in amax_list]
+
+    applier(
+        tex.multi_tensor_compute_scale_and_scale_inv,
+        overflow_buf,
+        [amax_list, scale_list, scale_inv_list],
+        max_fp8,
+        pow_2_scales,
+        epsilon,
+    )
+
+    for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
+        scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax(
+            amax, max_fp8, epsilon, pow_2_scales
+        )
+        torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
+        torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
tests/pytorch/test_sanity.py

@@ -37,7 +37,12 @@ from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.module.base import get_workspace
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from test_numerics import reset_rng_states, dtype_tols
 
 # Only run FP8 tests on supported devices.
@@ -1207,3 +1212,70 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
             outputs.append(p.grad)
     return outputs
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_replace_raw_data_for_float8tensor():
+    """Test the functionality of replace_raw_data"""
+    torch.manual_seed(12345)
+    torch.cuda.manual_seed(12345)
+
+    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
+    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
+    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
+    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
+
+    attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
+    attrs = {}
+    for attr in attrs_to_check:
+        attrs[attr] = getattr(fp8_tensor, attr)
+    old_data = fp8_tensor._data
+
+    new_data = torch.empty_like(old_data)
+    replace_raw_data(fp8_tensor, new_data)
+
+    # Make sure the new_data is properly assigned.
+    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
+    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
+    # Make sure the values are not changed.
+    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
+    # Make sure other attributes are not changed (totally identical)
+    for attr in attrs_to_check:
+        assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_fp8_model_init_high_precision_init_val():
+    """Test fp8_model_init with preserve_high_precision_init_val=True"""
+    with fp8_model_init(preserve_high_precision_init_val=True):
+        model = Linear(768, 768)
+
+    weight = model.weight
+    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
+    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
+    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
+    assert hasattr(
+        weight, "clear_high_precision_init_val"
+    ), "clear_high_precision_init_val() not found"
+
+    high_precision = weight.get_high_precision_init_val()
+    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
+
+    new_weight = weight._get_quantizer().make_empty(
+        shape=weight.shape, dtype=weight.dtype, device=weight.device
+    )
+    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)
+    torch.testing.assert_close(
+        new_weight.dequantize(dtype=weight.dtype),
+        weight.dequantize(dtype=weight.dtype),
+        rtol=0,
+        atol=0,
+    )
+
+    weight.clear_high_precision_init_val()
+    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
+    assert not hasattr(
+        weight, "._high_precision_init_val"
+    ), "clear_high_precision_init_val() not work"
transformer_engine/common/recipe/current_scaling.cu

@@ -13,6 +13,7 @@
 #include "../common.h"
 #include "../util/logging.h"
 #include "../util/vectorized_pointwise.h"
+#include "recipe_common.cuh"
 
 namespace transformer_engine {
 namespace {
@@ -135,7 +136,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
                "Output tensor for amax computation has invalid amax tensor "
                "(expected FP32, got dtype=",
                to_string(output.amax.dtype), ")");
-  CheckOutputTensor(output, "output_compute_amax");
+  CheckOutputTensor(output, "output_compute_amax", true);
 
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
@@ -151,41 +152,7 @@ namespace {
 __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                                const float max_fp8, const bool force_pow_2_scales,
                                                const float epsilon) {
-  float amax = *amax_ptr;
-  if (amax < epsilon) {
-    amax = epsilon;
-  }
-  float scale = 1.f;
-  if (isinf(amax) || amax == 0.f) {
-    *scale_ptr = scale;
-    return;
-  }
-  scale = max_fp8 / amax;
-  // The amax is too small that the scale becoming infinite in FP32. In other word,
-  // the scale is not representable in FP32.
-  if (isinf(scale)) {
-    // use fp32 max to represent the scale
-    scale = std::numeric_limits<float>::max();
-  }
-  if (isnan(scale)) {
-    scale = 1.f;
-  }
-  if (force_pow_2_scales) {
-    uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
-    scale_bits &= 0xFF800000;
-    // If the exponent was zero, we have a logic error.
-    __builtin_assume(scale_bits != 0);
-    __builtin_assume(scale_bits != 0x80000000);
-    scale = *reinterpret_cast<float *>(&scale_bits);
-  }
-  *scale_ptr = scale;
+  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon);
 }
 
 }  // namespace
transformer_engine/common/recipe/recipe_common.cuh (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_

#include <limits>

namespace transformer_engine {

__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
                                                         bool force_pow_2_scales, float epsilon) {
  if (amax < epsilon) {
    amax = epsilon;
  }
  float scale = 1.f;
  if (isinf(amax) || amax == 0.f) {
    return scale;
  }
  // Here we don't use "scale = max_fp8 / amax" because it has different results with/without
  // "--use_fast_math".
  // "__fdiv_rn" has the same behavior as "max_fp8 / amax" when not using fast math.
  scale = __fdiv_rn(max_fp8, amax);
  // The amax is so small that the scale becomes infinite in FP32; in other words,
  // the scale is not representable in FP32.
  if (isinf(scale)) {
    // Use FP32 max to represent the scale.
    scale = std::numeric_limits<float>::max();
  }
  if (isnan(scale)) {
    scale = 1.f;
  }
  if (force_pow_2_scales) {
    uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
    scale_bits &= 0xFF800000;
    // If the exponent was zero, we have a logic error.
    __builtin_assume(scale_bits != 0);
    __builtin_assume(scale_bits != 0x80000000);
    scale = *reinterpret_cast<float *>(&scale_bits);
  }
  return scale;
}

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
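
For reference, the scale computation above can be sketched in plain Python as below. It mirrors the reference helper added to tests/pytorch/references/ref_per_tensor_cs.py, but it is only an illustration, not code from the commit: it uses ordinary double-precision division instead of __fdiv_rn, so bit-exact agreement with the device code is not guaranteed.

# Illustrative Python rendering of compute_scale_from_amax (not part of the commit).
import math
import struct

FP32_MAX = 3.4028234663852886e38  # largest finite float32

def compute_scale_from_amax(amax, max_fp8, force_pow_2_scales, epsilon):
    amax = max(amax, epsilon)
    if math.isinf(amax) or amax == 0.0:
        return 1.0
    scale = max_fp8 / amax
    # amax so small that the scale overflows FP32: clamp to the FP32 max
    if math.isinf(scale):
        scale = FP32_MAX
    if math.isnan(scale):
        scale = 1.0
    if force_pow_2_scales:
        # Keep only the sign and exponent bits of the float32 value,
        # i.e. round the scale down to the nearest power of two.
        bits = struct.unpack("<I", struct.pack("<f", scale))[0] & 0xFF800000
        scale = struct.unpack("<f", struct.pack("<I", bits))[0]
    return scale

# Example: with the E4M3 max of 448 and an amax of 3.0, the scale is about 149.33,
# or 128.0 when power-of-two scales are forced.
print(compute_scale_from_amax(3.0, 448.0, False, 0.0))
print(compute_scale_from_amax(3.0, 448.0, True, 0.0))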

transformer_engine/pytorch/csrc/extensions.h

@@ -262,6 +262,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
  * FP8 recipe
 **************************************************************************************************/
 
+void compute_amax(const at::Tensor &tensor, at::Tensor &amax);
+
 void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
@@ -369,6 +371,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           float momentum, float dampening, float lr, bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);
 
+void multi_tensor_compute_scale_and_scale_inv_cuda(
+    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
+    float max_fp8, bool force_pow_2_scales, float epsilon);
+
 /***************************************************************************************************
  * padding
 **************************************************************************************************/
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>

#include "common/recipe/recipe_common.cuh"
#include "common/utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 256

struct ComputeScaleAndScaleInvFunctor {
  __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
                                             TensorListMetadata<3> &tl,  // NOLINT(*)
                                             float max_fp8, bool force_pow_2_scales,
                                             float epsilon) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float *amax = reinterpret_cast<float *>(tl.addresses[0][tensor_loc]);
    amax += chunk_idx * chunk_size;

    float *scale = reinterpret_cast<float *>(tl.addresses[1][tensor_loc]);
    scale += chunk_idx * chunk_size;

    float *scale_inv = reinterpret_cast<float *>(tl.addresses[2][tensor_loc]);
    scale_inv += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
      float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8,
                                                                    force_pow_2_scales, epsilon);
      scale[i_start] = scale_val;
      transformer_engine::reciprocal(scale_inv + i_start, scale_val);
    }
  }
};

void multi_tensor_compute_scale_and_scale_inv_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    float max_fp8, bool force_pow_2_scales, float epsilon) {
  using namespace at;

  multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                        ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);

  AT_CUDA_CHECK(cudaGetLastError());
}
transformer_engine/pytorch/csrc/extensions/normalization.cpp

@@ -6,12 +6,13 @@
 #include "common/util/system.h"
 #include "extensions.h"
+#include "pybind.h"
 
 namespace transformer_engine::pytorch {
 
 std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
                                                         py::handle quantizer) {
   std::vector<size_t> shape_vec;
-  for (int i = 0; i < shape.ndim; i++) {
+  for (size_t i = 0; i < shape.ndim; i++) {
     size_t t = shape.data[i];
     shape_vec.push_back(t);
   }
@@ -74,6 +75,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                       float eps, py::object out, py::handle quantizer,
                                       DType out_dtype, const int sm_margin,
                                       const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch::detail;
   using namespace transformer_engine::pytorch;
   using namespace transformer_engine;
@@ -107,14 +109,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   }
 
   // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }
   }
   TensorWrapper unquantized_out_cu;
@@ -145,6 +150,29 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
     nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                        at::cuda::getCurrentCUDAStream());
   }
@@ -196,6 +224,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                     py::object out, py::handle quantizer,
                                     transformer_engine::DType out_dtype, const int sm_margin,
                                     const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch::detail;
   using namespace transformer_engine::pytorch;
   using namespace transformer_engine;
@@ -223,14 +252,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   }
 
   // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
    }
   }
   TensorWrapper unquantized_out_cu;
@@ -261,6 +293,29 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
     nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                        at::cuda::getCurrentCUDAStream());
   }
transformer_engine/pytorch/csrc/extensions/pybind.cpp

@@ -181,6 +181,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
+  m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
         "Update amax history and FP8 scale/scale_inv after reduction",
         py::call_guard<py::gil_scoped_release>());
@@ -271,6 +272,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
         "Fused SGD optimizer for list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
+        "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
 
   // Data structures
   py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
transformer_engine/pytorch/csrc/extensions/recipe.cpp

@@ -12,10 +12,27 @@
 #include "common/common.h"
 #include "extensions.h"
 
+void compute_amax(const at::Tensor &tensor, at::Tensor &amax) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+
+  auto input_tensor = tensor.contiguous();
+  const TensorWrapper &te_input = makeTransformerEngineTensor(input_tensor);
+
+  TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
+  TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
+  TensorWrapper fake_te_output(
+      nullptr, te_input.shape(),
+      transformer_engine::DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
+      amax.data_ptr<float>());
+
+  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
+}
+
 void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
                                                  const std::string &amax_compute_algo,
                                                  transformer_engine::DType fp8_dtype,
                                                  float margin) {
   using namespace transformer_engine;
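
Based on the argument names and checks visible in this hunk (a contiguous input tensor plus a single-element FP32 amax tensor), the new binding could be exercised roughly as follows. This is an illustrative sketch, not code from the commit.

# Illustrative usage sketch for the compute_amax binding added above (not part of the commit).
import torch
import transformer_engine_torch as tex

x = torch.randn(128, 128, device="cuda")
amax = torch.empty(1, dtype=torch.float32, device="cuda")  # must be FP32 with exactly one element
tex.compute_amax(x, amax)
# The result should be the absolute maximum of the input.
assert torch.allclose(amax, x.abs().amax().reshape(1))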
transformer_engine/pytorch/fp8.py
View file @
035c48c0
...
@@ -100,6 +100,7 @@ class FP8GlobalStateManager:
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
+    HIGH_PRECISION_INIT_VAL = False
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
...
@@ -124,6 +125,7 @@ class FP8GlobalStateManager:
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
         cls.FP8_PARAMETERS = False
+        cls.HIGH_PRECISION_INIT_VAL = False
         cls.IS_FIRST_FP8_MODULE = False
         cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0
...
@@ -274,6 +276,11 @@ class FP8GlobalStateManager:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS

+    @classmethod
+    def with_high_precision_init_val(cls) -> bool:
+        """Should the high precision initial values be stored with FP8 parameters"""
+        return cls.HIGH_PRECISION_INIT_VAL
+
     @classmethod
     def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""
...
@@ -507,7 +514,11 @@ class FP8GlobalStateManager:
 @contextmanager
-def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None:
+def fp8_model_init(
+    enabled: bool = True,
+    recipe: Optional[Recipe] = None,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
...
@@ -518,6 +529,12 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
         with fp8_model_init(enabled=True):
             model = transformer_engine.pytorch.Linear(768, 768)

+        # Preserving high precision initial value to initialize master weight
+        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+        master_weight = model.weight.get_high_precision_init_val()
+        model.weight.clear_high_precision_init_val()
+
     Parameters
     ----------
     enabled: bool, default = `True`
...
@@ -533,18 +550,29 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
           * LoRA-like fine-tuning, where the main parameters of the model do not change.
     recipe: transformer_engine.common.recipe.Recipe, default = `None`
         Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
+    preserve_high_precision_init_val: bool, default = `False`
+        when enabled, store the high precision tensor used to initialize FP8 parameters
+        in CPU memory, and add two function attributes named `get_high_precision_init_val()`
+        and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
+        precision tensor. The purpose is that users can use this high-precision copy
+        to initialize master weights, avoiding the loss of precision that can occur when
+        using FP8 parameters directly. Note that after the master weights are initialized,
+        users should call `clear_high_precision_init_val()` to release this CPU memory.

     This functionality is *EXPERIMENTAL*.
     """
     _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
     _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
+    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
     FP8GlobalStateManager.FP8_PARAMETERS = enabled
     FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
+    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
     try:
         yield
     finally:
         FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
         FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
+        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val

 @contextmanager
...
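Putting the pieces of this diff together, one plausible way a training script could consume the new flag to build FP32 master weights is sketched below; only fp8_model_init, get_high_precision_init_val() and clear_high_precision_init_val() come from the change above, the loop and dictionary are illustrative.

    import transformer_engine.pytorch as te
    from transformer_engine.pytorch import fp8_model_init

    with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
        model = te.Linear(768, 768)

    master_weights = {}
    for name, param in model.named_parameters():
        init_val = getattr(param, "get_high_precision_init_val", lambda: None)()
        if init_val is None:
            init_val = param.detach().float().cpu()  # fall back to upcasting the FP8 value
        master_weights[name] = init_val.float()      # FP32 master copy for the optimizer
        if hasattr(param, "clear_high_precision_init_val"):
            param.clear_high_precision_init_val()    # release the CPU copy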
transformer_engine/pytorch/module/_common.py View file @ 035c48c0
...
@@ -4,7 +4,6 @@
 """Internal function used by multiple modules."""

-import os
 from typing import Any, List, Optional, Tuple, Union, Callable
 from dataclasses import dataclass
 from functools import reduce
...
@@ -16,9 +15,6 @@ from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
 from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.mxfp8_tensor import MXFP8Quantizer
-
-_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))

 def _get_normalization_func(normalization: str, forward: bool):
...
@@ -86,26 +82,16 @@ def apply_normalization(
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
-    split_mxfp8_cast = False
-    if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer):
-        split_mxfp8_cast = True
-    output = normalization_func(
+    return normalization_func(
         *inputs,
         eps,
-        None if split_mxfp8_cast else ln_out,
-        None if split_mxfp8_cast else output_quantizer,
+        ln_out,
+        output_quantizer,
         TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
         fwd_ln_sm_margin,
         zero_centered_gamma,
     )
-    return (
-        (output_quantizer.quantize(output[0], out=ln_out), *output[1:])
-        if split_mxfp8_cast
-        else output
-    )

 class _NoopCatFunc(torch.autograd.Function):
     """Concatenate tensors, doing a no-op if possible
...
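The removed split_mxfp8_cast branch covered the case where the fused norm kernel could not produce the quantized output directly, so the cast ran as a separate pass. A rough sketch of that two-step pattern under assumed names (this is not the library's kernel; the quantizer argument is whatever object exposes a quantize() method, as in the removed code):

    import torch

    def norm_then_quantize(x, weight, eps, output_quantizer=None):
        # Run the normalization in high precision first...
        out = torch.nn.functional.layer_norm(x, x.shape[-1:], weight, eps=eps)
        # ...then, if a quantizer is given, cast the result in a separate pass,
        # which is what the removed split-cast path did for MXFP8 output.
        if output_quantizer is not None:
            out = output_quantizer.quantize(out)
        return out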
transformer_engine/pytorch/module/base.py View file @ 035c48c0
...
@@ -10,6 +10,7 @@ import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
+from types import MethodType

 import torch
 import torch.nn.functional as F
...
@@ -424,6 +425,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.sequence_parallel = False
         self.param_init_meta = {}
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
+        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
         self.fsdp_wrapped = False
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
...
@@ -921,7 +923,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # If primary weights are in fp8, wrap the parameter as FP8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                if self.preserve_high_precision_init_val:
+                    high_precision_init_val = param.detach().cpu()
                 quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                 assert (
                     quantizer is not None
...
@@ -933,7 +939,34 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if high_precision_init_val is not None:
+
+                # - Master weights are initialized from model weights, if we use fp8 primary
+                #   weights to initialize master weights, the numerical values of master weights
+                #   are not consistent with the numerical values when we initialize them from
+                #   bf16/fp16 weights.
+                # - So we add a `_high_precision_init_val` attribute to each model weight to store
+                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
+                #   use `get_high_precision_init_val` to get this cpu tensor.
+                # - This cpu tensor is not needed once the master weight is initialized, so users
+                #   should call `clear_high_precision_init_val` to remove it after master weight
+                #   is initialized.
+                def get(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        return self._high_precision_init_val
+                    return None
+
+                def clear(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        del self._high_precision_init_val
+
+                param._high_precision_init_val = high_precision_init_val
+                param.get_high_precision_init_val = MethodType(get, param)
+                param.clear_high_precision_init_val = MethodType(clear, param)
+
+            setattr(self, name, param)

     @abstractmethod
     def forward(self):
...
@@ -972,6 +1005,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             FSDP process group that the weights are distributed over.
         """
+        # FP8 primary weights
+        if isinstance(tensor, QuantizedTensor):
+            if update_workspace and quantizer is not None:
+                tensor.update_usage(
+                    rowwise_usage=quantizer.rowwise_usage,
+                    columnwise_usage=quantizer.columnwise_usage,
+                )
+            return tensor
+
         # Try getting workspace from cache
         out = None
         if cache_name is not None:
...
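The MethodType calls above attach per-parameter accessors as bound methods on the nn.Parameter itself. A self-contained sketch of the same mechanism, with made-up tensor values:

    from types import MethodType
    import torch

    param = torch.nn.Parameter(torch.empty(4, 4))
    param._high_precision_init_val = torch.randn(4, 4)  # stand-in for the preserved CPU copy

    def get(self):
        return getattr(self, "_high_precision_init_val", None)

    def clear(self):
        if hasattr(self, "_high_precision_init_val"):
            del self._high_precision_init_val

    param.get_high_precision_init_val = MethodType(get, param)
    param.clear_high_precision_init_val = MethodType(clear, param)

    print(param.get_high_precision_init_val().shape)  # torch.Size([4, 4])
    param.clear_high_precision_init_val()
    print(param.get_high_precision_init_val())        # None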
transformer_engine/pytorch/module/grouped_linear.py View file @ 035c48c0
...
@@ -130,20 +130,17 @@ class _GroupedLinear(torch.autograd.Function):
             )

             weights_fp8 = []
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
-            if not isinstance(weights[0], QuantizedTensor):
             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
             for i in range(num_gemms):
                 weight_fp8 = module.get_weight_workspace(
                     tensor=weights[i],
                     quantizer=weight_quantizers[i],
                     cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                     update_workspace=update_workspace,
                     skip_update_flag=skip_fp8_weight_update,
                 )
                 weights_fp8.append(weight_fp8)
-            else:
-                weights_fp8 = weights

         else:
             inputmats = inputmats_no_fp8
...
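For context, a stripped-down sketch of the caching contract that get_weight_workspace provides for each expert weight here; the module-level dictionary and the quantizer.quantize() call are stand-ins for the module's actual internals:

    _workspace_cache = {}

    def get_weight_workspace(weight, quantizer, cache_name=None, update_workspace=True):
        # Reuse the cached FP8 copy on non-first microbatches...
        if cache_name is not None and not update_workspace and cache_name in _workspace_cache:
            return _workspace_cache[cache_name]
        # ...otherwise re-quantize the high-precision weight and refresh the cache.
        out = quantizer.quantize(weight)
        if cache_name is not None:
            _workspace_cache[cache_name] = out
        return out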
transformer_engine/pytorch/module/layernorm_linear.py View file @ 035c48c0
...
@@ -55,9 +55,9 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
 from ..cpp_extensions import (
     general_gemm,
...
@@ -160,11 +160,6 @@ class _LayerNormLinear(torch.autograd.Function):
         # Configure quantizer for normalization output
         with_quantized_norm = fp8 and not return_layernorm_output
-        # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer
-        # so we need to set with_quantized_norm to False
-        if isinstance(input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False
         if with_quantized_norm:
             if with_input_all_gather:
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -261,28 +256,26 @@ class _LayerNormLinear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")

         # Cast weight to expected dtype
-        weightmat = weight
-        quantized_weight = False
         if not fp8:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            quantized_weight = False
+            weightmat = cast_if_needed(weight, activation_dtype)
         else:
-            if not isinstance(weight, QuantizedTensor):
-                quantized_weight = True
+            quantized_weight = not isinstance(weight, QuantizedTensor)

             # Configure quantizer
             if weight_quantizer is not None:
                 weight_quantizer.set_usage(rowwise=True, columnwise=True)

             # FP8 cast to workspace buffer
             update_workspace = is_first_microbatch is None or is_first_microbatch
             weightmat = module.get_weight_workspace(
                 tensor=weight,
                 quantizer=weight_quantizer,
                 cache_name=(None if is_first_microbatch is None else "weight"),
                 update_workspace=update_workspace,
                 skip_update_flag=skip_fp8_weight_update,
                 fsdp_group=fsdp_group,
             )

         # Cast bias to expected dtype
         bias_dtype = activation_dtype
...
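On the user side, is_first_microbatch is what drives the update_workspace logic in these modules; a hedged usage sketch with an illustrative data loop (the module sizes and microbatch tensors are made up):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch import fp8_autocast

    layer = te.LayerNormLinear(1024, 1024).cuda()
    microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(4)]

    with fp8_autocast(enabled=True):
        for i, mb in enumerate(microbatches):
            # Weights are cast to the FP8 workspace only when i == 0; later
            # microbatches of the same step reuse the cached FP8 weights.
            out = layer(mb, is_first_microbatch=(i == 0))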