OpenDAS / TransformerEngine / Commits / ab3e5a92

Commit ab3e5a92, authored May 09, 2025 by yuguo

    Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine

Parents: a8d19fd9, 04c730c0

Showing 14 changed files with 1243 additions and 274 deletions (+1243, -274)
Changed files:

  transformer_engine/pytorch/ops/op.py                                          +51   -35
  transformer_engine/pytorch/optimizers/fused_adam.py                           +13    -5
  transformer_engine/pytorch/permutation.py                                    +156  -146
  transformer_engine/pytorch/tensor/__init__.py                                 +24    -0
  transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py  +246    -0
  transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py             +12    -2
  transformer_engine/pytorch/tensor/float8_blockwise_tensor.py                 +614    -0
  transformer_engine/pytorch/tensor/float8_tensor.py                             +0    -7
  transformer_engine/pytorch/tensor/mxfp8_tensor.py                              +4    -1
  transformer_engine/pytorch/tensor/quantized_tensor.py                         +10    -2
  transformer_engine/pytorch/tensor/utils.py                                    +46   -14
  transformer_engine/pytorch/transformer.py                                     +13    -0
  transformer_engine/pytorch/triton/permutation.py                              +40   -62
  transformer_engine/pytorch/utils.py                                           +14    -0
transformer_engine/pytorch/ops/op.py

@@ -17,8 +17,10 @@ from transformer_engine.common.recipe import Recipe
 from ..fp8 import (
     MXFP8BlockScalingRecipeState,
     DelayedScalingRecipeState,
+    Float8BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
+    fp8_autocast,
 )
 from ..tensor import Quantizer

@@ -218,6 +220,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             if num_quantizers == 0:
                 continue
+            if recipe.float8_block_scaling():
+                raise NotImplementedError(
+                    "Fusible operations do not support FP8 block scaling recipe"
+                )
+
             # Construct quantization recipe state
             recipe_state = RecipeState.create(
                 recipe,

@@ -259,8 +266,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 continue
             recipe_state = self._fp8_metas[mode][fp8_meta_key]
             need_to_reset_recipe_state = (
-                recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)
-            ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
+                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
+                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
+                or (
+                    recipe.float8_block_scaling()
+                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
+                )
+            )
             if need_to_reset_recipe_state:
                 self._reset_quantization_recipe_state(recipe=recipe)
             return

@@ -508,7 +520,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     def get_extra_state(self) -> torch.Tensor:
         """Serialize extra state

-        Contains metadata for FP8 casting.
+        Contains metadata for quantization recipe.

         """

@@ -540,21 +552,25 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             dst.copy_(src, non_blocking=True)
             return dst

-        # Store FP8 state
+        # Store quantizer state if needed
         state = {}
         for mode in ("forward", "backward"):

-            # Get state for a given FP8 tensor
-            if self.num_quantizers(mode) == 0:
+            # Skip if op has no quantizer state
+            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                 continue
-            fp8_meta = self.get_fp8_meta(mode)

+            # Quantizer state
+            fp8_meta = self._fp8_metas[mode]
             state[mode] = {}
+            state[mode]["recipe"] = fp8_meta["recipe"]

-            # Store tensors
-            if "scaling_fwd" in fp8_meta:
-                state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
-                state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
-            if "scaling_bwd" in fp8_meta:
-                state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
-                state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
+            # Copy tensors to CPU and store
+            if state[mode]["recipe"].delayed():
+                if mode == "forward":
+                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
+                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
+                if mode == "backward":
+                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
+                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

@@ -595,37 +611,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
             dst.copy_(src, non_blocking=True)

-        # Load FP8 state
+        # Load quantizer state if needed
         for mode in ("forward", "backward"):

-            # Get state for a given FP8 tensor
+            # Skip if checkpoint has no quantizer state
             if mode not in state:
                 continue
-            fp8_meta = self.get_fp8_meta(mode)
-            if fp8_meta is None:
+            if self.num_quantizers(mode) == 0:
                 continue

-            # Load extra state
+            # Get op's quantizer state, initializing if needed
+            if self._fp8_metas is None or self._fp8_metas[mode] is None:
+                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
+                    self._reset_quantization_recipe_state()
+            fp8_meta = self._fp8_metas[mode]
+
+            # Load extra items
+            fp8_meta["recipe"] = state[mode]["recipe"]
             fp8_meta.update(state[mode]["extra_fp8_variables"])
+            if "amax_history_fwd" in state[mode]:
+                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
+            elif "amax_history_bwd" in state[mode]:
+                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
             if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                 del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

             # Load tensors
-            fp8_meta = self.get_fp8_meta(mode)
-            if "scaling_fwd" in fp8_meta:
-                fp8_meta_fwd = fp8_meta["scaling_fwd"]
-                copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
-                copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
-            if "scaling_bwd" in fp8_meta:
-                fp8_meta_bwd = fp8_meta["scaling_bwd"]
-                copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
-                copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)
+            if state[mode]["recipe"].delayed():
+                if mode == "forward":
+                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
+                    copy_tensor(
+                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
+                    )
+                if mode == "backward":
+                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
+                    copy_tensor(
+                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
+                    )

         # Finish CPU-GPU memory transfers
         torch.cuda.synchronize()
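For reference — a minimal pure-PyTorch sketch (not from the repository) of the CPU staging pattern that get_extra_state/set_extra_state use above; the to_cpu helper name mirrors the diff, and the scale/amax tensors are hypothetical stand-ins:

import torch

def to_cpu(src: torch.Tensor) -> torch.Tensor:
    # Stage a CUDA tensor in pinned CPU memory with an asynchronous copy.
    dst = torch.empty(src.size(), dtype=src.dtype, device="cpu", pin_memory=True)
    dst.copy_(src, non_blocking=True)
    return dst

if torch.cuda.is_available():
    scale = torch.ones(4, device="cuda")               # hypothetical FP8 scale tensor
    amax_history = torch.zeros(16, 4, device="cuda")   # hypothetical amax history
    state = {
        "scale_fwd": to_cpu(scale),
        "amax_history_fwd": to_cpu(amax_history),
    }
    torch.cuda.synchronize()  # finish CPU-GPU memory transfers before serializing `state`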
transformer_engine/pytorch/optimizers/fused_adam.py

@@ -133,10 +133,10 @@ class FusedAdam(torch.optim.Optimizer):
         # Add constraints to dtypes of states.
         if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
             raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
-        if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
-        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
+        if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.")
+        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.")

         # Currently, capturable mode only supports fp32 master weights and optimizer states.
         # The reason is, if the master weights or optimizer states are not in fp32 dtype,

@@ -259,6 +259,10 @@ class FusedAdam(torch.optim.Optimizer):
             scale (torch.Tensor): A FP32 tensor representing the scaling factor.
         """
         assert unscaled_state.dtype == torch.float32
+        if scaled_state.dtype == torch.bfloat16:
+            scaled_state.copy_(unscaled_state.bfloat16())
+            return
+
         dtype = self.name_to_dtype_map[state_name]
         if dtype == torch.uint8:
             assert isinstance(scaled_state, Float8Tensor)

@@ -313,8 +317,11 @@ class FusedAdam(torch.optim.Optimizer):
             else:
                 assert state[state_name].dtype == torch.float32
                 unscaled = state[state_name]
+        elif dtype == torch.bfloat16:
+            assert state[state_name].dtype == torch.bfloat16
+            unscaled = state[state_name].float()
         else:
-            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
+            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.")
         return unscaled

     def set_scaled_state(self, param, state_name, unscaled_state):

@@ -329,6 +336,7 @@ class FusedAdam(torch.optim.Optimizer):
             and 'master_param`.
             unscaled_state (torch.Tensor): The original high-precision(FP32) state.
         """
         store_param_remainders = (
             self.store_param_remainders
             and state_name == "master_param"
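For reference — a standalone sketch (not the FusedAdam constructor itself) of the dtype constraints after this change, with torch.bfloat16 now accepted for the exp_avg/exp_avg_sq states:

import torch

def check_fused_adam_state_dtypes(master_weights, master_weight_dtype, exp_avg_dtype, exp_avg_sq_dtype):
    # Mirrors the constraint checks shown in the diff above.
    if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
        raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
    if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
        raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.")
    if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
        raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.")

# bf16 moments now pass the check; before this change they raised RuntimeError.
check_fused_adam_state_dtypes(True, torch.float32, torch.bfloat16, torch.bfloat16)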
transformer_engine/pytorch/permutation.py

@@ -4,14 +4,16 @@
 """MoE Permutaion API"""
 import warnings
-from typing import Tuple
+from typing import Optional, Tuple

 import torch

 import transformer_engine_torch as tex
 import transformer_engine.pytorch.triton.permutation as triton_permutation
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor

 __all__ = [
     "moe_permute",

@@ -46,16 +48,6 @@ class _moe_permute_index_map(torch.autograd.Function):
         assert inp.size(0) == index.size(0), "Permute not possible"

         # Data type check
-        fp8 = isinstance(inp, Float8Tensor)
-        if fp8:
-            assert (
-                inp._quantizer.scale.ndim == 0
-            ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
-            dtype = inp._fp8_dtype
-            fp8_scale_inv = inp._scale_inv
-            fake_dtype = inp.dtype
-            inp = inp._data
-        else:
-            dtype = TE_DType[inp.dtype]
+        dtype = TE_DType[inp.dtype]
         if index.dtype != torch.int32:
             warnings.warn(

@@ -80,19 +72,9 @@ class _moe_permute_index_map(torch.autograd.Function):
             _moe_permute_index_map.max_expanded_token_num,
         )

-        if fp8:
-            permuted_act = Float8Tensor(
-                data=permuted_act,
-                fp8_dtype=dtype,
-                fp8_scale_inv=fp8_scale_inv,
-                shape=permuted_act.shape,
-                dtype=fake_dtype,
-            )
-
         ctx.row_id_map = row_id_map
         ctx.num_tokens = index.size(0)
         ctx.topK = index.size(1)
-        ctx.fp8 = fp8
         return permuted_act, row_id_map

     @staticmethod

@@ -109,30 +91,12 @@ class _moe_permute_index_map(torch.autograd.Function):
         if not permuted_act_grad.is_contiguous():
             permuted_act_grad = permuted_act_grad.contiguous()

-        if ctx.fp8:
-            assert isinstance(
-                permuted_act_grad, Float8Tensor
-            ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
-            dtype = permuted_act_grad._fp8_dtype
-            fp8_scale_inv = permuted_act_grad._scale_inv
-            fake_dtype = permuted_act_grad.dtype
-            permuted_act_grad = permuted_act_grad._data
-        else:
-            dtype = TE_DType[permuted_act_grad.dtype]
+        dtype = TE_DType[permuted_act_grad.dtype]

         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad = tex.moe_permute_bwd(
                 permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
             )
-            if ctx.fp8:
-                act_grad = Float8Tensor(
-                    data=act_grad,
-                    fp8_dtype=dtype,
-                    fp8_scale_inv=fp8_scale_inv * ctx.topK,
-                    shape=act_grad.shape,
-                    dtype=fake_dtype,
-                )

         return act_grad, None, None, None

@@ -176,13 +140,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
         assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

         # Data type check
-        fp8 = isinstance(inp, Float8Tensor)
-        if fp8:
-            dtype = inp._fp8_dtype
-            fp8_scale_inv = inp._scale_inv
-            fake_dtype = inp.dtype
-            inp = inp._data
-        else:
-            dtype = TE_DType[inp.dtype]
+        dtype = TE_DType[inp.dtype]
         if row_id_map.dtype != torch.int32:
             warnings.warn(

@@ -193,17 +150,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
         unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)
-        if fp8:
-            unpermuted_output = Float8Tensor(
-                data=unpermuted_output,
-                fp8_dtype=dtype,
-                fp8_scale_inv=fp8_scale_inv,
-                shape=unpermuted_output.shape,
-                dtype=fake_dtype,
-            )
         ctx.save_for_backward(inp, row_id_map, probs)
-        ctx.fp8 = fp8
         return unpermuted_output

     @staticmethod

@@ -219,17 +166,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
         if not unpermuted_act_grad.is_contiguous():
             unpermuted_act_grad = unpermuted_act_grad.contiguous()

-        if ctx.fp8:
-            assert isinstance(
-                unpermuted_act_grad, Float8Tensor
-            ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
-            dtype = unpermuted_act_grad._fp8_dtype
-            fp8_scale_inv = unpermuted_act_grad._scale_inv
-            fake_dtype = unpermuted_act_grad.dtype
-            unpermuted_act_grad = unpermuted_act_grad._data
-        else:
-            dtype = TE_DType[unpermuted_act_grad.dtype]
+        dtype = TE_DType[unpermuted_act_grad.dtype]

         inp, row_id_map, probs = ctx.saved_tensors
         act_grad = None

@@ -238,14 +175,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
             act_grad, prob_grad = tex.moe_unpermute_bwd(
                 unpermuted_act_grad, inp, dtype, row_id_map, probs
             )
-            if ctx.fp8:
-                act_grad = Float8Tensor(
-                    data=act_grad,
-                    fp8_dtype=dtype,
-                    fp8_scale_inv=fp8_scale_inv,
-                    shape=act_grad.shape,
-                    dtype=fake_dtype,
-                )

         if not ctx.needs_input_grad[2]:
             prob_grad = None

@@ -282,22 +211,54 @@ class _moe_permute_mask_map(torch.autograd.Function):
         row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

-        fp8 = isinstance(inp, Float8Tensor)
+        fp8 = isinstance(inp, QuantizedTensor)
+        per_tensor_recipe = isinstance(inp, Float8Tensor)
+        blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
+        mxfp8_recipe = isinstance(inp, MXFP8Tensor)
         if fp8:
             fp8_dtype = inp._fp8_dtype
-            fp8_scale_inv = inp._scale_inv
             fake_dtype = inp.dtype
-            inp = inp._data
-        output, permuted_probs = triton_permutation.permute_with_mask_map(
+            # blockwise scaling
+            if blockwise_recipe:
+                fp8_scale = inp._rowwise_scale_inv.T.contiguous()
+                scale_hidden_dim = fp8_scale.shape[1]
+                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+                inp = inp._rowwise_data
+            # mxfp8 scaling
+            elif mxfp8_recipe:
+                fp8_scale = inp._rowwise_scale_inv.contiguous()
+                scale_hidden_dim = fp8_scale.shape[1]
+                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+                inp = inp._rowwise_data
+            # per-tensor scaling
+            elif per_tensor_recipe:
+                # Kernel does not need scale in per-tensor scaling
+                fp8_scale = None
+                scale_hidden_dim = None
+                fp8_scale_inv = inp._scale_inv
+                inp = inp._data
+            else:
+                raise ValueError("Unsupported FP8 recipe")
+        else:
+            fp8_scale = None
+            fp8_dtype = None
+            scale_hidden_dim = None
+        output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
             inp,
             row_id_map,
             probs,
+            fp8_scale,
             num_tokens,
             num_experts,
             num_out_tokens,
             hidden_size,
+            scale_hidden_dim,
         )
         if fp8:
+            if per_tensor_recipe:
                 output = Float8Tensor(
                     data=output,
                     fp8_dtype=fp8_dtype,

@@ -305,6 +266,31 @@ class _moe_permute_mask_map(torch.autograd.Function):
                     shape=output.shape,
                     dtype=fake_dtype,
                 )
+            elif blockwise_recipe:
+                output = Float8BlockwiseQTensor(
+                    shape=output.shape,
+                    dtype=fake_dtype,
+                    rowwise_data=output,
+                    rowwise_scale_inv=permuted_scale.T.contiguous(),
+                    columnwise_data=None,
+                    columnwise_scale_inv=None,
+                    fp8_dtype=fp8_dtype,
+                    quantizer=None,
+                    is_2D_scaled=False,
+                    requires_grad=output.requires_grad,
+                )
+            elif mxfp8_recipe:
+                output = MXFP8Tensor(
+                    shape=output.shape,
+                    dtype=fake_dtype,
+                    fp8_dtype=fp8_dtype,
+                    rowwise_data=output,
+                    rowwise_scale_inv=permuted_scale.contiguous(),
+                    columnwise_data=None,
+                    columnwise_scale_inv=None,
+                    quantizer=None,
+                    requires_grad=output.requires_grad,
+                )

         ctx.save_for_backward(row_id_map)
         ctx.num_experts = num_experts

@@ -327,14 +313,9 @@ class _moe_permute_mask_map(torch.autograd.Function):
         probs_grad = None
         if ctx.needs_input_grad[0]:
             (row_id_map,) = ctx.saved_tensors
-            fp8 = isinstance(permuted_act_grad, Float8Tensor)
-            if fp8:
-                fp8_dtype = permuted_act_grad._fp8_dtype
-                fp8_scale_inv = permuted_act_grad._scale_inv
-                fake_dtype = permuted_act_grad.dtype
-                permuted_act_grad = permuted_act_grad._data
-            else:
-                fp8_dtype = None
+            assert not isinstance(
+                permuted_act_grad, QuantizedTensor
+            ), "The backward of moe_permute does not support FP8."
             act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
                 permuted_act_grad,
                 row_id_map,

@@ -343,15 +324,6 @@ class _moe_permute_mask_map(torch.autograd.Function):
                 ctx.num_tokens,
                 ctx.num_experts,
                 ctx.hidden_size,
-                fp8_dtype,
             )
-            if fp8:
-                act_grad = Float8Tensor(
-                    data=act_grad,
-                    fp8_dtype=fp8_dtype,
-                    fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
-                    shape=act_grad.shape,
-                    dtype=fake_dtype,
-                )

         if not ctx.needs_input_grad[3]:
             probs_grad = None

@@ -366,8 +338,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
         ctx,
         inp: torch.Tensor,
         row_id_map: torch.Tensor,
-        merging_probs: torch.Tensor,
-        restore_shape: torch.Size,
+        merging_probs: Optional[torch.Tensor],
+        restore_shape: Optional[torch.Size],
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
         if not inp.numel():

@@ -387,17 +359,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
         assert inp.is_cuda, "TransformerEngine needs CUDA."
         assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

-        fp8 = isinstance(inp, Float8Tensor)
-        if fp8:
-            fp8_dtype = inp._fp8_dtype
-            if not with_probs:
-                fp8_scale_inv = inp._scale_inv * num_experts
-            else:
-                fp8_scale_inv = inp._scale_inv
-            fake_dtype = inp.dtype
-            inp = inp._data
-        else:
-            fp8_dtype = None
+        assert not isinstance(
+            inp, QuantizedTensor
+        ), "The forward of moe_unpermute does not support FP8."
         unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
             inp,
             row_id_map,

@@ -406,15 +370,6 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
             num_tokens,
             num_experts,
             hidden_size,
-            fp8_dtype=fp8_dtype,
         )
-        if fp8:
-            unpermuted_output = Float8Tensor(
-                data=unpermuted_output,
-                fp8_dtype=fp8_dtype,
-                fp8_scale_inv=fp8_scale_inv,
-                shape=unpermuted_output.shape,
-                dtype=fake_dtype,
-            )

         if with_probs:

@@ -442,16 +397,44 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
         else:
             (row_id_map,) = ctx.saved_tensors
-            fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
+            fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
+            per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
+            blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
+            mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
             if fp8:
                 fp8_dtype = unpermuted_act_grad._fp8_dtype
-                fp8_scale_inv = unpermuted_act_grad._scale_inv
                 fake_dtype = unpermuted_act_grad.dtype
-                unpermuted_act_grad = unpermuted_act_grad._data
+                # per-tensor scaling
+                if per_tensor_recipe:
+                    # Kernel does not need scale in per-tensor scaling
+                    fp8_scale = None
+                    scale_hidden_dim = None
+                    fp8_scale_inv = unpermuted_act_grad._scale_inv
+                    unpermuted_act_grad = unpermuted_act_grad._data
+                # blockwise scaling
+                elif blockwise_recipe:
+                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
+                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
+                    scale_hidden_dim = fp8_scale.shape[1]
+                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+                # mxfp8 scaling
+                elif mxfp8_recipe:
+                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
+                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
+                    scale_hidden_dim = fp8_scale.shape[1]
+                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+                else:
+                    raise ValueError("Unsupported FP8 recipe")
             else:
+                scale_hidden_dim = None
                 fp8_dtype = None
+                fp8_scale = None
             if ctx.with_probs:
+                assert (
+                    not fp8
+                ), "The backward of moe_unpermute with merging probs does not support FP8."
                 act_grad, probs_grad = (
                     triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
                         unpermuted_act_grad,

@@ -462,21 +445,23 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
                         ctx.num_experts,
                         ctx.num_permuted_tokens,
                         ctx.hidden_size,
-                        fp8_dtype,
                     )
                 )
             else:
-                act_grad, _ = triton_permutation.permute_with_mask_map(
+                act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
                     unpermuted_act_grad,
                     row_id_map,
                     None,
+                    fp8_scale,
                     ctx.num_tokens,
                     ctx.num_experts,
                     ctx.num_permuted_tokens,
                     ctx.hidden_size,
+                    scale_hidden_dim,
                 )
             if fp8:
+                if per_tensor_recipe:
                     act_grad = Float8Tensor(
                         data=act_grad,
                         fp8_dtype=fp8_dtype,

@@ -484,6 +469,31 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
                         shape=act_grad.shape,
                         dtype=fake_dtype,
                     )
+                elif blockwise_recipe:
+                    act_grad = Float8BlockwiseQTensor(
+                        shape=act_grad.shape,
+                        dtype=fake_dtype,
+                        rowwise_data=act_grad,
+                        rowwise_scale_inv=permuted_scale.T.contiguous(),
+                        columnwise_data=None,
+                        columnwise_scale_inv=None,
+                        fp8_dtype=fp8_dtype,
+                        quantizer=None,
+                        is_2D_scaled=False,
+                        requires_grad=act_grad.requires_grad,
+                    )
+                elif mxfp8_recipe:
+                    act_grad = MXFP8Tensor(
+                        shape=act_grad.shape,
+                        dtype=fake_dtype,
+                        fp8_dtype=fp8_dtype,
+                        rowwise_data=act_grad,
+                        rowwise_scale_inv=permuted_scale.contiguous(),
+                        columnwise_data=None,
+                        columnwise_scale_inv=None,
+                        quantizer=None,
+                        requires_grad=act_grad.requires_grad,
+                    )

         if not ctx.needs_input_grad[2]:
             probs_grad = None

@@ -568,10 +578,10 @@ def moe_permute_with_probs(
 def moe_unpermute(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
-    merging_probs: torch.Tensor = None,
-    restore_shape: torch.Tensor = None,
+    merging_probs: Optional[torch.Tensor] = None,
+    restore_shape: Optional[torch.Size] = None,
     map_type: str = "mask",
-    probs: torch.Tensor = None,
+    probs: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Unpermute a tensor with permuted tokens, and optionally merge the tokens with their

@@ -588,7 +598,7 @@ def moe_unpermute(
         The tensor of probabilities corresponding to the permuted tokens. If provided,
         the unpermuted tokens will be merged with their respective probabilities.
         By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
-    restore_shape: torch.Tensor
+    restore_shape: torch.Size, default = None
         The output shape after the unpermute operation.
     map_type: str, default = 'mask'
         Type of the routing map tensor. Should be the same as the value passed to moe_permute.
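An illustrative sketch (plain tensors, not the TE tensor classes) of why permute_with_mask_map now carries a per-token scale alongside the data: with rowwise 1x128 or MXFP8 scaling, each token row owns its own scale entries, so permuting token rows must apply the same permutation to the scale rows. The shapes and permutation below are hypothetical.

import torch

num_tokens, hidden, block = 4, 256, 128
data = torch.randn(num_tokens, hidden)
scale = torch.rand(num_tokens, hidden // block)  # one scale per 128-wide block per token row

perm = torch.tensor([2, 0, 3, 1])  # hypothetical permutation of token rows
permuted_data = data[perm]
permuted_scale = scale[perm]       # the scale rows follow their tokens

assert torch.equal(permuted_data[0], data[2])
assert torch.equal(permuted_scale[0], scale[2])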
transformer_engine/pytorch/tensor/__init__.py

@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype):
 torch.nn.Module.float = _make_module_cast_func(torch.float32)
 torch.nn.Module.half = _make_module_cast_func(torch.float16)
 torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16)
+
+
+def get_all_tensor_types():
+    """
+    Get all tensor-like types that can be used in TE.
+    """
+    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
+    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
+    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
+        Float8BlockwiseQTensor,
+        Float8BlockwiseQTensorBase,
+    )
+
+    all_tensor_types = [
+        torch.Tensor,
+        torch.nn.Parameter,
+        Float8Tensor,
+        Float8TensorBase,
+        MXFP8Tensor,
+        MXFP8TensorBase,
+        Float8BlockwiseQTensor,
+        Float8BlockwiseQTensorBase,
+    ]
+    return all_tensor_types
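A brief usage sketch of the new helper (assumes a Transformer Engine build that includes this change):

from transformer_engine.pytorch.tensor import get_all_tensor_types

# Lists torch.Tensor, torch.nn.Parameter and the TE quantized tensor classes,
# now including Float8BlockwiseQTensor and Float8BlockwiseQTensorBase.
for tensor_type in get_all_tensor_types():
    print(tensor_type.__name__)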
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py  (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Mixin class holding data specific for Float8BlockwiseQTensor"""

from __future__ import annotations
import math
from typing import Optional, Dict, Any, Tuple

import torch

from transformer_engine_torch import DType as TE_DType
from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer


class Float8BlockwiseQTensorBase:
    """Mixin class that holds data attributes of Float8BlockwiseQTensor.

    Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
    mixin class. If this class is instantiated directly, it has the same
    data, lower CPU overhead, and less functionality. It should only
    be instantiated directly for performance-critical internal usage.
    """

    _rowwise_data: Optional[torch.Tensor]
    _columnwise_data: Optional[torch.Tensor]
    _quantizer: Quantizer
    _fp8_dtype: TE_DType
    _rowwise_scale_inv: Optional[torch.Tensor]
    _columnwise_scale_inv: Optional[torch.Tensor]
    _is_2D_scaled: bool

    def __new__(
        cls,
        *args,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: Optional[torch.Tensor],
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: Optional[torch.Tensor],
        fp8_dtype: TE_DType,
        quantizer: Quantizer,
        is_2D_scaled: bool,
        **kwargs,
    ):
        instance = super().__new__(cls, *args, **kwargs)
        instance._rowwise_data = rowwise_data
        instance._columnwise_data = columnwise_data
        instance._quantizer = quantizer
        instance._fp8_dtype = fp8_dtype
        instance._rowwise_scale_inv = rowwise_scale_inv
        instance._columnwise_scale_inv = columnwise_scale_inv
        instance._is_2D_scaled = is_2D_scaled

        return instance

    def get_metadata(self) -> Dict[str, Any]:
        """Get this tensor's metadata."""
        return {
            "rowwise_data": self._rowwise_data,
            "rowwise_scale_inv": self._rowwise_scale_inv,
            "columnwise_data": self._columnwise_data,
            "columnwise_scale_inv": self._columnwise_scale_inv,
            "fp8_dtype": self._fp8_dtype,
            "quantizer": self._quantizer,
            "is_2D_scaled": self._is_2D_scaled,
        }

    def prepare_for_saving(
        self,
    ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
        """
        Prepare the tensor base for saving for backward

        This does not clear the tensors currently, because with PP config
        that clears the weight cache between micro-batches. If the rowwise
        data is not required for backward, this is a possible memory
        pessimization, but is consistent with the other quantized tensor
        classes.
        """
        tensors = [self._rowwise_data, self._columnwise_data]
        return tensors, self

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the tensor base data from the saved tensors list."""
        self._rowwise_data = tensors[0]
        self._columnwise_data = tensors[1]
        return tensors[2:]

    def get_data_tensors(self):
        """Get this Tensor's data."""
        return self._rowwise_data, self._columnwise_data

    def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
        """Takes dequantized columnwise data and permutes to a rowwise shape"""
        if columnwise_dq.dim() < 2:
            return columnwise_dq
        permute_dims = list(range(1, columnwise_dq.dim()))
        permute_dims.append(0)
        return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()

    def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        block_len = 128
        q_M, q_K = 1, 1

        if self._rowwise_data is not None:
            q = self._rowwise_data
            scale_inv = self._rowwise_scale_inv
            transpose_output = False
            if len(q.shape) >= 1:
                q_K = q.shape[-1]
                for i in range(len(q.shape) - 1):
                    q_M *= q.shape[i]
        else:
            assert self._columnwise_data is not None, "No data to dequantize"
            q = self._columnwise_data
            scale_inv = self._columnwise_scale_inv
            transpose_output = True
            if len(q.shape) >= 1:
                q_M = q.shape[0]
                for i in range(1, len(q.shape)):
                    q_K *= q.shape[i]

        orig_shape = q.shape
        q = q.reshape(q_M, q_K)

        k_tiles, scale_m = scale_inv.shape

        if q_K % block_len != 0:
            k_pad_amount = (block_len - (q_K % block_len)) % block_len
            q = torch.nn.functional.pad(
                q, (0, k_pad_amount, 0, 0), mode="constant", value=0
            ).contiguous()
        _, padded_K = q.shape
        q_tiled = q.reshape(q_M, k_tiles, block_len)

        if scale_m > q_M:
            # scale_m is 4 element aligned.
            scale_inv = scale_inv[:, :q_M].contiguous()
        dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
        torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]

        result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
        if padded_K != q_K:
            result = result.reshape(q_M, padded_K)[:, :q_K]
        result = result.to(dtype)

        if len(orig_shape) == 0:
            result = result.reshape([])
        else:
            result = result.reshape(*orig_shape).contiguous()

        if transpose_output:
            return self._transpose_dq_columnwise_output(result)
        return result

    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """
        Construct plain PyTorch tensor from Float8BlockwiseQTensor
        """
        block_len = 128
        if not self._is_2D_scaled:
            return self._dequantize_vectorwise(dtype=dtype)

        def format_scale_as_logical_shape(q_K, scales, block_len):
            # The GEMM for 2D blocks required padding in the scales.
            derived_scale_k_shape = math.ceil(q_K / block_len)
            _, scale_K = scales.shape
            if derived_scale_k_shape == scale_K:
                return scales
            return scales[:, :derived_scale_k_shape].contiguous()

        q_M, q_K = 1, 1
        if self._rowwise_data is not None:
            q = self._rowwise_data
            scale_inv = self._rowwise_scale_inv
            transpose_output = False
            if len(q.shape) >= 1:
                q_K = q.shape[-1]
                for i in range(len(q.shape) - 1):
                    q_M *= q.shape[i]
        else:
            assert self._columnwise_data is not None, "No data to dequantize"
            q = self._columnwise_data
            scale_inv = self._columnwise_scale_inv
            transpose_output = True
            if len(q.shape) >= 1:
                q_M = q.shape[0]
                for i in range(1, len(q.shape)):
                    q_K *= q.shape[i]
        orig_shape = q.shape
        q = q.reshape(q_M, q_K)
        formatted_scales = format_scale_as_logical_shape(q_K, scale_inv, block_len)
        assert len(formatted_scales.shape) == 2
        m_tiles, k_tiles = formatted_scales.shape

        unpadded_m, unpadded_k = q_M, q_K
        m_block_len = block_len
        k_block_len = block_len
        if q_M % m_block_len != 0 or q_K % k_block_len != 0:
            m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len
            k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len
            q = torch.nn.functional.pad(
                q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
            ).contiguous()
        padded_M, padded_K = q.shape
        q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len)
        torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]

        result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view(
            m_tiles, 1, k_tiles, 1
        )
        result = result.view(padded_M, padded_K).to(dtype)
        if padded_M != unpadded_m or padded_K != unpadded_k:
            result = result[:unpadded_m, :unpadded_k]
        if len(orig_shape) == 0:
            result = result.reshape([])
        else:
            result = result.reshape(*orig_shape).contiguous()

        if transpose_output:
            return self._transpose_dq_columnwise_output(result)
        return result

    def size(self, *args, **kwargs):
        # pylint: disable=missing-function-docstring
        if self._rowwise_data is not None:
            return self._rowwise_data.size(*args, **kwargs)
        dims = list(self._columnwise_data.size(*args, **kwargs))
        reordered = []
        for i in range(1, len(dims)):
            reordered.append(dims[i])
        reordered.append(dims[0])
        return torch.Size(reordered)

    def __repr__(self):
        if self._rowwise_data is not None:
            data = self.dequantize()
            descriptor = "rowwise"
        else:
            data = self.dequantize()
            descriptor = "columnwise"
        return (
            "Float8BlockwiseQTensorBase("
            f"fp8_dtype={self._fp8_dtype}, "
            f"{descriptor}_scaled_data={data}"
        )
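A simplified, pure-PyTorch sketch of the 1x128 (vectorwise) dequantization path implemented by _dequantize_vectorwise above; the FP8 payload is stood in for by a float tensor, and the scale layout follows the (k_tiles, M) convention used for _rowwise_scale_inv:

import torch

block_len = 128
M, K = 3, 200                       # hypothetical shape; K is not a multiple of 128
q = torch.randn(M, K)               # stand-in for FP8 data already viewed as float32
k_tiles = (K + block_len - 1) // block_len
scale_inv = torch.rand(k_tiles, M)  # dequantization scales, stored as (k_tiles, M)

# Pad K up to a multiple of the block length, then view as (M, k_tiles, block_len).
k_pad = (block_len - (K % block_len)) % block_len
q_tiled = torch.nn.functional.pad(q, (0, k_pad)).reshape(M, k_tiles, block_len)

# Each 1x128 tile is multiplied by its own scale, then the padding is stripped.
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(M, k_tiles, 1)
result = (q_tiled * dq_scale).reshape(M, k_tiles * block_len)[:, :K]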
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py

@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function):
         dtype: torch.dtype,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
-        dtype = torch_to_transformer_engine_dtype[dtype]
+        te_dtype = torch_to_transformer_engine_dtype[dtype]

         # Make sure FP8 data is in expected format
         if tensor._data is not None:
+            if tensor._data.numel() == 0:
+                return torch.empty_like(tensor._data, dtype=dtype)
             # Cast from FP8
-            return tex.dequantize(tensor, dtype)
+            return tex.dequantize(tensor, te_dtype)
         raise NotImplementedError("Casting back from the transpose not implemented yet!")

@@ -134,3 +136,11 @@ class Float8TensorBase:
             f"data={self.dequantize()}"
             ")"
         )
+
+    def _create_transpose(self):
+        """Update FP8 transpose cache"""
+        data = self._data
+        if not data.is_contiguous():
+            data = data.contiguous()
+        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
+        self._transpose_invalid = False
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data quantized with NxN tiles"""
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
import
math
import
torch
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
._internal.float8_blockwise_tensor_base
import
Float8BlockwiseQTensorBase
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
from
..utils
import
devices_match
,
round_up_to_nearest_multiple
aten
=
torch
.
ops
.
aten
class
Float8BlockQuantizer
(
Quantizer
):
"""Builder class for tensors quantized with current scaling using
NxN quantization tilings to choose scale.
This class is typically used to convert a high-precision tensor
(e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8).
"""
dtype
:
TE_DType
block_len
:
int
amax_epsilon
:
float
force_pow_2_scales
:
bool
block_scaling_dim
:
int
def
__init__
(
self
,
fp8_dtype
:
TE_DType
,
*
,
rowwise
:
bool
,
columnwise
:
bool
,
amax_epsilon
:
float
=
0.0
,
force_pow_2_scales
:
bool
=
True
,
block_scaling_dim
:
int
=
2
,
)
->
None
:
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
dtype
=
fp8_dtype
self
.
block_len
=
128
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
amax_epsilon
=
amax_epsilon
self
.
block_scaling_dim
=
block_scaling_dim
def
update_quantized
(
self
,
src
:
torch
.
Tensor
,
dst
:
QuantizedTensor
,
*
,
noop_flag
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
QuantizedTensor
:
"""Update the quantized tensor with data from the source tensor.
This method quantizes the input tensor and stores the result in the destination tensor.
Parameters
----------
src : torch.Tensor
Source tensor containing the data to be quantized
dst : QuantizedTensor
Destination tensor where the quantized data will be stored
noop_flag : Optional[torch.Tensor]
Optional flag tensor indicating whether to skip the quantization operation
Returns
-------
QuantizedTensor
The destination tensor containing the quantized data
Raises
------
AssertionError
If the destination tensor is not a Float8BlockwiseQTensor
"""
assert
isinstance
(
dst
,
Float8BlockwiseQTensor
),
f
"Cannot store quantized blockwise tensor in
{
type
(
dst
)
}
type."
# Make sure input is in expected format
if
not
devices_match
(
src
.
device
,
dst
.
device
):
src
=
src
.
to
(
device
=
dst
.
device
)
if
not
src
.
is_contiguous
():
src
=
src
.
contiguous
()
# Launch cast kernel
tex
.
quantize
(
src
,
self
,
dst
,
noop_flag
)
dst
.
_fp8_dtype
=
self
.
dtype
return
dst
def
get_scale_shape
(
self
,
shape
:
Iterable
[
int
],
columnwise
:
bool
)
->
Tuple
[
int
,
int
]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For 2D tensors:
- If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
- If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
For 1D tensors:
- If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
- If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
"""
M
,
K
=
1
,
1
for
i
in
range
(
len
(
shape
)
-
1
):
M
*=
shape
[
i
]
if
len
(
shape
)
>
0
:
K
=
shape
[
-
1
]
if
self
.
block_scaling_dim
==
2
:
if
columnwise
:
outer
=
math
.
ceil
(
K
/
self
.
block_len
)
inner
=
round_up_to_nearest_multiple
(
math
.
ceil
(
M
/
self
.
block_len
),
4
)
return
(
outer
,
inner
)
outer
=
math
.
ceil
(
M
/
self
.
block_len
)
inner
=
round_up_to_nearest_multiple
(
math
.
ceil
(
K
/
self
.
block_len
),
4
)
return
(
outer
,
inner
)
assert
self
.
block_scaling_dim
==
1
,
"Only 1D or 2D blocks supported"
if
columnwise
:
outer
=
math
.
ceil
(
M
/
self
.
block_len
)
inner
=
round_up_to_nearest_multiple
(
K
,
4
)
return
(
outer
,
inner
)
outer
=
math
.
ceil
(
K
/
self
.
block_len
)
inner
=
round_up_to_nearest_multiple
(
M
,
4
)
return
(
outer
,
inner
)
def
get_columnwise_shape
(
self
,
shape
:
Iterable
[
int
])
->
Tuple
[
int
,
...]:
"""Calculate the shape of a tensor after columnwise permutation.
This method rearranges the dimensions of a tensor to be columnwise,
moving the last dimension to the front and keeping the order of other dimensions.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
"""
if
len
(
shape
)
==
0
:
return
tuple
()
colwise_shape
=
[
shape
[
-
1
]]
for
i
in
range
(
len
(
shape
)
-
1
):
colwise_shape
.
append
(
shape
[
i
])
return
tuple
(
colwise_shape
)
# TODO(kwyss): With FP8 gather support, we need to implement a
# shape/layout/swizzle check to know whether FP8 gather works
# cleanly by stacking data without aliasing tiles and whether
# the scales also stack on the proper dimensions.
def
make_empty
(
self
,
shape
:
Iterable
[
int
],
*
,
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
False
,
)
->
Float8BlockwiseQTensor
:
"""Construct quantized tensor with uninitialized data"""
if
device
is
None
:
device
=
torch
.
device
(
"cuda"
)
# Allocate FP8 data
data
=
None
scale_inv
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
False
)
scale_inv
=
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
device
,
)
# Allocate FP8 data transpose if needed
columnwise_data
=
None
columnwise_scale_inv
=
None
if
self
.
columnwise_usage
:
columnwise_data
=
torch
.
empty
(
self
.
get_columnwise_shape
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_shape
,
dtype
=
torch
.
float32
,
device
=
device
,
)
# Construct FP8 tensor
return
Float8BlockwiseQTensor
(
shape
=
shape
,
dtype
=
dtype
,
fp8_dtype
=
self
.
dtype
,
rowwise_data
=
data
,
rowwise_scale_inv
=
scale_inv
,
columnwise_data
=
columnwise_data
,
columnwise_scale_inv
=
columnwise_scale_inv
,
quantizer
=
self
,
is_2D_scaled
=
self
.
block_scaling_dim
==
2
,
requires_grad
=
requires_grad
,
)
def
calibrate
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
# NOTE: This interface is specific to requirements like delayed scaling
# where state from an estimator influences distribution parameters.
pass
class
Float8BlockwiseQTensor
(
Float8BlockwiseQTensorBase
,
QuantizedTensor
):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP8. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
rowwise_data: torch.Tensor
FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv: torch.Tensor
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data: Optional[torch.Tensor]
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv: Optional[torch.Tensor]
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and
holds configuration about quantization and dequantization modes.
"""
def
__repr__
(
self
,
*
,
tensor_contents
=
None
):
return
(
f
"Float8BlockwiseQTensor(fp8_dtype=
{
self
.
_fp8_dtype
}
,"
f
" is_2D_scaled=
{
self
.
_is_2D_scaled
}
,"
f
" data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
)"
)
def
_get_quantizer
(
self
)
->
Quantizer
:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
assert
self
.
_quantizer
is
not
None
return
self
.
_quantizer
def
quantize_
(
self
,
tensor
:
torch
.
Tensor
,
*
,
noop_flag
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Float8BlockwiseQTensor
:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if
isinstance
(
tensor
,
QuantizedTensor
):
return
self
.
quantize_
(
tensor
.
dequantize
())
self
.
_get_quantizer
().
update_quantized
(
tensor
,
self
,
noop_flag
=
noop_flag
)
return
self
def
dequantize
(
self
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
torch
.
Tensor
:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
By default the resulting tensor's dtype is the
Float8BlockwiseQTensor's pre-quantized dtype.
"""
if
dtype
is
not
None
:
dequant_dtype
=
dtype
else
:
dequant_dtype
=
self
.
dtype
return
super
().
dequantize
(
dtype
=
dequant_dtype
)
def
detach
(
self
)
->
Float8BlockwiseQTensor
:
# pylint: disable=missing-function-docstring
return
Float8BlockwiseQTensor
.
make_like
(
self
)
def
update_usage
(
self
,
rowwise_usage
:
Optional
[
bool
]
=
None
,
columnwise_usage
:
Optional
[
bool
]
=
None
):
"""
update_usage can be used to clear out one of two possible copies of the data.
"""
if
rowwise_usage
is
None
:
rowwise_usage
=
self
.
_rowwise_data
is
not
None
if
columnwise_usage
is
None
:
columnwise_usage
=
self
.
_columnwise_data
is
not
None
assert
(
columnwise_usage
or
rowwise_usage
),
"Must retain some data either columnwise or rowwise"
if
columnwise_usage
and
rowwise_usage
:
assert
(
self
.
_rowwise_data
is
not
None
and
self
.
_rowwise_scale_inv
is
not
None
and
self
.
_columnwise_data
is
not
None
and
self
.
_columnwise_scale_inv
is
not
None
),
"Cannot update to rowwise and columnwise usage."
return
if
rowwise_usage
:
assert
(
self
.
_rowwise_data
is
not
None
and
self
.
_rowwise_scale_inv
is
not
None
),
"Cannot update to rowwise usage."
self
.
_columnwise_data
=
None
self
.
_columnwise_scale_inv
=
None
return
if
columnwise_usage
:
assert
(
self
.
_columnwise_data
is
not
None
and
self
.
_columnwise_scale_inv
is
not
None
),
"Cannot update to columnwise usage."
self
.
_rowwise_data
=
None
self
.
_rowwise_scale_inv
=
None
return
return
def
clone
(
self
)
->
Float8BlockwiseQTensor
:
# pylint: disable=missing-function-docstring
rowwise_data
=
None
if
self
.
_rowwise_data
is
not
None
:
rowwise_data
=
self
.
_rowwise_data
.
detach
().
clone
()
columnwise_data
=
None
if
self
.
_columnwise_data
is
not
None
:
columnwise_data
=
self
.
_columnwise_data
.
detach
().
clone
()
return
_IdentityFunc
.
apply
(
self
,
{
"rowwise_data"
:
rowwise_data
,
"columnwise_data"
:
columnwise_data
,
},
)
def
view
(
self
,
*
shape
:
Tuple
[
int
])
->
Float8BlockwiseQTensor
:
# pylint: disable=missing-function-docstring
return
_ViewFunc
.
apply
(
self
,
shape
)
def
reshape
(
self
,
*
shape
:
Tuple
[
int
])
->
Float8BlockwiseQTensor
:
# pylint: disable=missing-function-docstring
return
_ReshapeFunc
.
apply
(
self
,
shape
)
@
classmethod
def
__torch_dispatch__
(
cls
,
func
,
types
,
args
,
kwargs
=
None
):
# View op
if
func
==
aten
.
view
.
default
:
tensor
=
args
[
0
]
data
=
tensor
.
_rowwise_data
if
data
is
None
:
# Columnwise data only.
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
orig_size
=
data
.
size
()
out_data
=
data
.
__torch_dispatch__
(
func
,
types
,
[
data
]
+
list
(
args
[
1
:]),
kwargs
,
)
if
orig_size
!=
out_data
.
size
():
raise
NotImplementedError
(
"Changing shape with view not implemented "
" (scales and columnwise data untouched)."
)
return
Float8BlockwiseQTensor
.
make_like
(
tensor
)
# Default case
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
def
contiguous
(
self
,
memory_format
:
torch
.
memory_format
=
torch
.
contiguous_format
,
)
->
Float8BlockwiseQTensor
:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if
(
self
.
_rowwise_data
is
not
None
and
self
.
_rowwise_data
.
is_contiguous
(
memory_format
=
memory_format
)
and
(
(
self
.
_columnwise_data
is
None
)
or
(
self
.
_columnwise_data
.
is_contiguous
(
memory_format
=
memory_format
))
)
):
return
self
raise
ValueError
(
"Float8BlockwiseQTensor does not support different memory formats!"
)
def
clear
(
self
):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self
.
_rowwise_data
=
torch
.
Tensor
()
if
self
.
_rowwise_data
is
not
None
else
None
self
.
_columnwise_data
=
torch
.
Tensor
()
if
self
.
_columnwise_data
is
not
None
else
None
    @classmethod
    def _make_in_reduce_ex(
        cls,
        shape: torch.Size,
        rowwise_data: torch.Tensor,
        rowwise_scale_inv: torch.Tensor,
        columnwise_data: torch.Tensor,
        columnwise_scale_inv: torch.Tensor,
        fp8_dtype: TE_DType,
        dtype: torch.dtype,
        quantizer: Quantizer,
        is_2D_scaled: bool,
    ) -> Float8BlockwiseQTensor:
        """Build Float8BlockwiseQTensor, for use in __reduce__

        __reduce_ex__ assumes object constructor has positional
        arguments.
        """
        return Float8BlockwiseQTensor(
            shape=shape,
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            fp8_dtype=fp8_dtype,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            dtype=dtype,
            quantizer=quantizer,
            is_2D_scaled=is_2D_scaled,
        )

    def __reduce_ex__(self, protocol: int) -> tuple:
        """Custom pickling to remove references to FP8 metadata objects"""
        return (
            Float8BlockwiseQTensor._make_in_reduce_ex,
            (
                self.shape,
                self._rowwise_data,
                self._rowwise_scale_inv,
                self._columnwise_data,
                self._columnwise_scale_inv,
                self._fp8_dtype,
                self.dtype,
                self._quantizer,
                self._is_2D_scaled,
            ),
        )
    def _get_data(self) -> Float8BlockwiseQTensor:
        """Get tensor data property"""
        return self

    @torch.no_grad()
    def _set_data(self, tensor: torch.Tensor) -> None:
        """Set tensor data property

        Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise
        casts to FP8.
        """

        # Tensor device
        new_device = tensor.device if tensor.is_cuda else self.device

        def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
            dst._rowwise_data = src._rowwise_data
            dst._columnwise_data = src._columnwise_data
            dst._quantizer = src._quantizer
            dst._fp8_dtype = src._fp8_dtype
            dst._rowwise_scale_inv = src._rowwise_scale_inv
            dst._columnwise_scale_inv = src._columnwise_scale_inv

        # Check that tensor dimensions match
        if (
            self.size() != tensor.size()
            or self.stride() != tensor.stride()
            or self.layout != tensor.layout
        ):
            raise ValueError("Invalid tensor for updating Float8BlockwiseQTensor data")

        # Just copy FP8 data if other tensor is Float8BlockwiseQTensor
        if (
            isinstance(tensor, Float8BlockwiseQTensor)
            and self.storage_offset() == tensor.storage_offset()
            and devices_match(self.device, new_device)
        ):
            _set_from_tensor(self, tensor)
            return

        if isinstance(tensor, Float8BlockwiseQTensor):
            assert tensor._quantizer is not None, "Can't quantize without a quantizer"
            quantizer = tensor._quantizer
        else:
            assert self._quantizer is not None, "Can't quantize without a quantizer"
            quantizer = self._quantizer

        # Quantize to FP8
        quantizer.update_quantized(tensor, self)

    # Cast to FP8 when setting Float8BlockwiseQTensor.data
    data = property(_get_data, _set_data)
class _ViewFunc(torch.autograd.Function):
    """View function

    View the Float8BlockwiseQTensor using the provided shape.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: Float8BlockwiseQTensor,
        shape: Optional[list[int]] = None,
    ) -> Float8BlockwiseQTensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        if ctx is not None:
            ctx.shape = tensor.shape
        if shape is None:
            return tensor
        if list(shape) != list(tensor.shape):
            raise NotImplementedError("View not implemented.")
        return tensor

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad, Float8BlockwiseQTensor):
            raise NotImplementedError("View bwd not implemented")
        return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
    """Reshape function

    Reshape the Float8BlockwiseQTensor using the provided shape.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: Float8BlockwiseQTensor,
        shape: Optional[list[int]] = None,
    ) -> Float8BlockwiseQTensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        if ctx is not None:
            ctx.shape = tensor.shape
        if shape is None:
            return tensor

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(tensor.shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if list(shape) != list(tensor.shape):
            raise NotImplementedError("Reshape not implemented yet.")
        return tensor

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad, Float8BlockwiseQTensor):
            raise NotImplementedError("Reshape bwd not implemented yet.")
        return grad.view(ctx.shape), None
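The update_usage method above clears whichever FP8 copy (row-wise or column-wise) is no longer needed. A minimal usage sketch follows (not part of the commit); it assumes qtensor is an already-quantized Float8BlockwiseQTensor that currently carries both copies and uses only the attributes visible in the listing above.

# Sketch only: drop the column-wise copy once only the row-wise GEMM path is needed.
def keep_rowwise_only(qtensor):
    qtensor.update_usage(rowwise_usage=True, columnwise_usage=False)
    # The column-wise buffers are released; the row-wise ones must still be present.
    assert qtensor._columnwise_data is None and qtensor._columnwise_scale_inv is None
    assert qtensor._rowwise_data is not None and qtensor._rowwise_scale_inv is not None
    return qtensor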
transformer_engine/pytorch/tensor/float8_tensor.py
View file @ ab3e5a92
...
@@ -422,13 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8Tensor.make_like(self)
 
-    def _create_transpose(self):
-        data = self._data
-        if not data.is_contiguous():
-            data = data.contiguous()
-        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
-        self._transpose_invalid = False
-
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
...
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @ ab3e5a92
...
@@ -347,6 +347,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         columnwise_scale_inv: torch.Tensor,
         fp8_dtype: TE_DType,
         dtype: torch.dtype,
+        shape: torch.shape,
     ) -> MXFP8Tensor:
         """Build MXFP8Tensor, for use in __reduce__
...
@@ -361,10 +362,11 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             dtype=dtype,
+            shape=shape,
         )
 
     def __reduce_ex__(self, protocol: int) -> tuple:
-        """Custom pickling to remove references to FP8 metadata objects"""
+        """Custom pickling"""
         return (
             MXFP8Tensor._make_in_reduce_ex,
             (
...
@@ -374,6 +376,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
                 self._columnwise_scale_inv,
                 self._fp8_dtype,
                 self.dtype,
+                self.shape,
             ),
         )
...
transformer_engine/pytorch/tensor/quantized_tensor.py
View file @ ab3e5a92
...
@@ -37,7 +37,8 @@ def prepare_for_saving(
 def restore_from_saved(
     tensors: list[Optional[Any]],
     saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
-) -> list[Optional[Any]]:
+    return_saved_tensors: bool = False,
+) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
     """Recombine the tensor data and metadata during backward pass."""
     tensor_objects = []
     for tensor in tensors:
...
@@ -47,6 +48,9 @@ def restore_from_saved(
         else:
             saved_tensors = tensor.restore_from_saved(saved_tensors)
         tensor_objects.append(tensor)
+
+    if return_saved_tensors:
+        return tensor_objects, saved_tensors
     return tensor_objects
...
@@ -113,7 +117,11 @@ class Quantizer(abc.ABC):
         """Quantize tensor in-place"""
 
     def quantize(
-        self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None
+        self,
+        tensor: torch.Tensor,
+        *,
+        out: Optional[QuantizedTensor] = None,
+        dtype: Optional[torch.dtype] = None,  # pylint: disable=unused-argument # used by override
     ) -> QuantizedTensor:
         """Quantize tensor"""
         if out is not None:
...
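The new return_saved_tensors flag changes what restore_from_saved hands back. Below is a hedged sketch of the round trip (not part of the commit); it assumes prepare_for_saving keeps its existing convention of returning the flattened tensors plus the metadata objects, and exact signatures may differ.

# Sketch only: round-tripping tensors through prepare_for_saving / restore_from_saved.
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)

def save_then_restore(*tensors):
    # Split each (possibly quantized) tensor into plain torch.Tensors plus metadata objects.
    plain_tensors, tensor_objects = prepare_for_saving(*tensors)
    # ... plain_tensors would normally be handed to ctx.save_for_backward ...
    restored, remaining = restore_from_saved(
        tensor_objects, plain_tensors, return_saved_tensors=True
    )
    # `remaining` holds whatever saved entries were not consumed (expected empty here).
    return restored, remaining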
transformer_engine/pytorch/tensor/utils.py
View file @ ab3e5a92
...
@@ -39,7 +39,9 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
         raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")
 
 
-def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, group):
+def cast_master_weights_to_fp8(
+    model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None
+):
     r"""Helper function to cast master weights to FP8 primary weights.
 
     This is intended for use with ZeRO/FSDP. Each rank has a shard of
...
@@ -56,14 +58,23 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
                     should be updated.
     group         : The distributed group to do amax reduction. Typically it's the data parallel
                     group.
+    fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
+                    not sharded. Otherwise, it means that the model weights are sharded and we get
+                    target model weights data storage using the FSDP shard model weights.
     """
 
     delayed_scaling_params = []
     current_scaling_params = []
 
-    for model_weight, master_weight, start_offset in zip(
-        model_weights, master_weights, start_offsets
+    if fsdp_shard_model_weights is None:
+        use_fsdp_shard_model_weights = False
+        fsdp_shard_model_weights = [None] * len(model_weights)
+    else:
+        use_fsdp_shard_model_weights = True
+
+    for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip(
+        model_weights, master_weights, start_offsets, fsdp_shard_model_weights
     ):
         # Clear `_high_precision_init_val` of model_weight automatically.
         # - Master weights are initialized from model weights, if we use fp8 primary weights to
...
@@ -89,9 +100,13 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
         quantizer = model_weight._get_quantizer()
         if isinstance(quantizer, Float8Quantizer):
-            delayed_scaling_params.append((model_weight, master_weight, start_offset))
+            delayed_scaling_params.append(
+                (model_weight, master_weight, start_offset, fsdp_shard_model_weight)
+            )
         elif isinstance(quantizer, Float8CurrentScalingQuantizer):
-            current_scaling_params.append((model_weight, master_weight, start_offset))
+            current_scaling_params.append(
+                (model_weight, master_weight, start_offset, fsdp_shard_model_weight)
+            )
         elif isinstance(quantizer, MXFP8Quantizer):
             raise NotImplementedError(
                 "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
...
@@ -102,12 +117,16 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
     )
 
     if len(delayed_scaling_params) > 0:
-        _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, group)
+        _cast_master_weights_to_fp8_delayed_scaling(
+            delayed_scaling_params, group, use_fsdp_shard_model_weights
+        )
     if len(current_scaling_params) > 0:
-        _cast_master_weights_to_fp8_current_scaling(current_scaling_params, group)
+        _cast_master_weights_to_fp8_current_scaling(
+            current_scaling_params, group, use_fsdp_shard_model_weights
+        )
 
 
-def _cast_master_weights_to_fp8_delayed_scaling(params, group):
+def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
     r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.
 
     Parameters
...
@@ -116,13 +135,14 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
                     indicating the starting index of the master weight in the model weight.
     group         : The distributed group to do amax reduction. Typically it's the data parallel
                     group.
+    use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
     """
 
     # Collect amaxes to do reduce-max among dp group.
     # Collect scales and scale_invs to update scale_invs of the fp8 weights.
     amaxes, scales, scale_invs = [], [], []
 
-    for model_weight, master_weight, start_offset in params:
+    for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
         # Reset transpose cache for all model weights.
         # We cannot create transpose cache here because users (like megatron) may want to overlap
         # the all-gather of model weights and forward process, so the model weight is not updated
...
@@ -148,6 +168,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
         # master_weight may be smaller than model_weight because it could be distributed across
         # multiple ranks. So we need to create a dummy weight using the raw data from model_weight.
-        shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset]
+        if not use_fsdp_shard_model_weights:
+            shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset]
         shard_model_weight_fp8 = quantizer.create_tensor_from_data(
             shard_model_weight_raw.view(1, -1),
...
@@ -187,7 +208,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
     )
 
 
-def _cast_master_weights_to_fp8_current_scaling(params, group):
+def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
     r"""Helper function to cast master weights to FP8 primary weights for current scaling.
 
     Parameters
...
@@ -196,6 +217,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
                     indicating the starting index of the master weight in the model weight.
     group         : The distributed group to do amax reduction. Typically it's the data parallel
                     group.
+    use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
     """
 
     # Parameter attributes
...
@@ -220,7 +242,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
     # amaxes in a contiguous buffer. If the master weight is None, the corresponding amax
     # will be set to 0.
     # ---------------------------------------------------------------------------------------------
-    for (model_weight, master_weight, _), amax in zip(params, amaxes):
+    for (model_weight, master_weight, _, _), amax in zip(params, amaxes):
         # Make sure all the model weights have the same numerical options.
         quantizer = model_weight._get_quantizer()
...
@@ -261,7 +283,9 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
     # ---------------------------------------------------------------------------------------------
     # Step 4: Cast master weights to FP8.
     # ---------------------------------------------------------------------------------------------
-    for (model_weight, master_weight, start_offset), scale in zip(params, scales):
+    for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
+        params, scales
+    ):
         # Reset transpose cache for all model weights.
         # We cannot create transpose cache here because users (like megatron) may want to overlap
         # the all-gather of model weights and forward process, so the model weight is not updated
...
@@ -275,10 +299,18 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
         # Cast master weight to FP8
         end_offset = start_offset + master_weight.numel()
-        model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset]
+        if not use_fsdp_shard_model_weights:
+            model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset]
         quantizer = Float8Quantizer(
             scale=scale,
             amax=torch.Tensor(),
             fp8_dtype=model_weight._fp8_dtype,
         )
+        if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor):
+            # NOTE: The fsdp shard model weight may be a unit8 tensor instead of
+            # a float8 tensor. We should handle this situation properly.
+            model_weight_fragment = quantizer.create_tensor_from_data(
+                model_weight_fragment.view(-1),
+                model_weight.dtype,
+            )
         quantizer.update_quantized(master_weight, model_weight_fragment)
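A hedged sketch of calling cast_master_weights_to_fp8 with the new fsdp_shard_model_weights argument follows (not part of the commit); the weight lists and the process group are placeholders supplied by the caller's optimizer.

# Sketch only: driving the new FSDP-shard path. Each list entry corresponds to one FP8
# model weight, its local master-weight shard, and the offset of that shard in the flat weight.
from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8

def refresh_fp8_weights(model_weights, master_weights, start_offsets, dp_group, shards=None):
    cast_master_weights_to_fp8(
        model_weights,
        master_weights,
        start_offsets,
        dp_group,                         # group used for the amax reduction
        fsdp_shard_model_weights=shards,  # None keeps the pre-existing non-sharded behavior
    )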
transformer_engine/pytorch/transformer.py
View file @ ab3e5a92
...
@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
+from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.attention import (
     MultiheadAttention,
 )
...
@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import (
     dist_group_type,
 )
 from transformer_engine.pytorch.distributed import get_distributed_world_size
+from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 
 warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
...
@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module):
                      head size. Note that these formats are very closely
                      related to the `qkv_format` in the `MultiHeadAttention`
                      and `DotProductAttention` modules.
+    name: str, default = `None`
+        name of the module, currently used for debugging purposes.
 
     Parallelism parameters
     ----------------------
...
@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module):
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
+        name: str = None,
     ) -> None:
         super().__init__()
...
@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module):
         self.attn_input_format = attn_input_format
 
+        self.name = name
+
         attention_args = (
             hidden_size,
             num_attention_heads,
...
@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module):
             return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
+            name=name + ".self_attention" if name is not None else None,
         )
 
         if layer_type == "decoder":
...
@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
                 return_bias=True,
                 normalization=normalization,
                 device=device,
+                name=name + ".inter_attention" if name is not None else None,
             )
 
         # LayerNorm -> activation(Linear + Bias) -> Linear
...
@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module):
             activation=activation,
             normalization=normalization,
             device=device,
+            name=name + ".layernorm_mlp" if name is not None else None,
         )
 
         self.hidden_dropout = hidden_dropout
...
@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module):
                 enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
             ), "Encoder-decoder attention mask must be boolean tensor(s)"
 
+        if TEDebugState.debug_enabled:
+            TransformerEngineBaseModule._validate_name(self)
+
         # For AMP
         if torch.is_autocast_enabled():
             hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
...
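A hedged sketch of the new name argument on TransformerLayer (not part of the commit); the sizes are illustrative and the dotted name is simply propagated to the attention and MLP submodules for debug bookkeeping.

# Sketch only: naming a layer so TE debug tooling can identify its submodules.
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    name="decoder.layers.0",  # submodules become "decoder.layers.0.self_attention", ".layernorm_mlp", ...
)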
transformer_engine/pytorch/triton/permutation.py
View file @ ab3e5a92
...
@@ -10,15 +10,6 @@ import torch
 import triton
 import triton.language as tl
 
-from transformer_engine_torch import DType as TE_DType
-from torch.utils.cpp_extension import IS_HIP_EXTENSION
-
-if IS_HIP_EXTENSION:
-    e5m2_data_type = tl.float8e5b16
-    e4m3_data_type = tl.float8e4b8
-else:
-    e5m2_data_type = tl.float8e5
-    e4m3_data_type = tl.float8e4nv
 
 @triton.jit
 def _row_id_map_pass_1_kernel(
...
@@ -123,11 +114,14 @@ def _permute_kernel(
     output_ptr,
     row_id_map_ptr,
     probs_ptr,
+    scale_ptr,
     permuted_probs_ptr,
+    permuted_scale_ptr,
     # sizes
     num_tokens,
     num_experts,
     hidden_size,
+    scale_hidden_dim,
     # strides
     stride_input_token,
     stride_input_hidden,
...
@@ -135,9 +129,14 @@ def _permute_kernel(
     stride_output_hidden,
     stride_probs_token,
     stride_probs_expert,
+    stride_scale_token,
+    stride_scale_hidden,
     stride_permuted_probs_token,
+    stride_permuted_scale_token,
+    stride_permuted_scale_hidden,
     # metas
     PERMUTE_PROBS: tl.constexpr,
+    PERMUTE_SCALE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(0)
...
@@ -147,11 +146,21 @@ def _permute_kernel(
         mask = cur_off < hidden_size
         input_off = pid * stride_input_token + cur_off * stride_input_hidden
         inp = tl.load(input_ptr + input_off, mask=mask)
+        if PERMUTE_SCALE:
+            mask_scale = cur_off < scale_hidden_dim
+            scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
+            scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
         for expert_idx in range(num_experts):
             dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
             if dst_row != -1:
                 output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
                 tl.store(output_ptr + output_off, inp, mask=mask)
+                if PERMUTE_SCALE:
+                    permuted_scale_off = (
+                        dst_row * stride_permuted_scale_token
+                        + cur_off * stride_permuted_scale_hidden
+                    )
+                    tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
                 if PERMUTE_PROBS:
                     if cur_pos == 0:
                         prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
...
@@ -180,10 +189,12 @@ def permute_with_mask_map(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
     probs: torch.Tensor,
+    scale: torch.Tensor,
     num_tokens: int,
     num_experts: int,
     num_out_tokens: int,
     hidden_size: int,
+    scale_hidden_dim: int,
 ):
     # pylint: disable=missing-function-docstring
     output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
...
@@ -191,26 +202,42 @@ def permute_with_mask_map(
         permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
     else:
         permuted_probs = None
+    if scale is not None:
+        permuted_scale = torch.empty(
+            (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
+        )
+    else:
+        permuted_scale = None
     grid = (num_tokens,)
     _permute_kernel[grid](
         inp,
         output,
         row_id_map,
         probs,
+        scale,
         permuted_probs,
+        permuted_scale,
         num_tokens,
         num_experts,
         hidden_size,
+        scale_hidden_dim,
         inp.stride(0),
         inp.stride(1),
         output.stride(0),
         output.stride(1),
         probs.stride(0) if probs is not None else None,
         probs.stride(1) if probs is not None else None,
+        scale.stride(0) if scale is not None else None,
+        scale.stride(1) if scale is not None else None,
         permuted_probs.stride(0) if permuted_probs is not None else None,
+        permuted_scale.stride(0) if permuted_scale is not None else None,
+        permuted_scale.stride(1) if permuted_scale is not None else None,
         PERMUTE_PROBS=probs is not None,
+        PERMUTE_SCALE=scale is not None,
     )
-    return output, permuted_probs
+    return output, permuted_scale, permuted_probs
 
 
 @triton.jit
...
@@ -239,18 +266,9 @@ def _unpermute_kernel(
     # metas
     WITH_MERGING_PROBS: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
-    FP8_DTYPE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    if FP8_DTYPE == "e5m2":
-        data_type = tl.float8e5
-        pytorch_tensor_dtype = tl.uint8
-    elif FP8_DTYPE == "e4m3":
-        data_type = tl.float8e4nv
-        pytorch_tensor_dtype = tl.uint8
-    else:
-        data_type = input_ptr.dtype.element_ty
-        assert FP8_DTYPE is None
+    data_type = input_ptr.dtype.element_ty
     compute_type = tl.float32
     pid = tl.program_id(0)
...
@@ -264,8 +282,6 @@ def _unpermute_kernel(
             if src_row != -1:
                 input_off = src_row * stride_input_token + current_offset * stride_input_hidden
                 inp = tl.load(input_ptr + input_off, mask=mask)
-                if FP8_DTYPE is not None:
-                    inp = inp.to(data_type, bitcast=True)
                 inp = inp.to(compute_type)
                 if WITH_MERGING_PROBS:
                     merging_prob_off = (
...
@@ -286,13 +302,6 @@ def _unpermute_kernel(
                     tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
                 else:
                     tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
-        if FP8_DTYPE is not None:
-            if not WITH_MERGING_PROBS:
-                # Directly adding these value may cause overflow for fp8, we scale it here.
-                # The outside fp8_scale_inv is also scaled in the meantime.
-                accumulator /= num_experts
-            accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
-        else:
-            accumulator = accumulator.to(data_type)
+        accumulator = accumulator.to(data_type)
         output_off = pid * stride_output_token + current_offset * stride_output_hidden
         tl.store(output_ptr + output_off, accumulator, mask=mask)
...
@@ -322,15 +331,8 @@ def unpermute_with_mask_map(
     num_tokens: int,
     num_experts: int,
     hidden_size: int,
-    fp8_dtype: TE_DType,
 ):
     # pylint: disable=missing-function-docstring
-    if fp8_dtype == TE_DType.kFloat8E5M2:
-        fp8_dtype = "e5m2"
-    elif fp8_dtype == TE_DType.kFloat8E4M3:
-        fp8_dtype = "e4m3"
-    else:
-        fp8_dtype = None
     output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
     if permuted_probs is not None:
         unpermuted_probs = torch.empty(
...
@@ -360,7 +362,6 @@ def unpermute_with_mask_map(
         unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
         WITH_MERGING_PROBS=merging_probs is not None,
         PERMUTE_PROBS=permuted_probs is not None,
-        FP8_DTYPE=fp8_dtype,
     )
     return output, unpermuted_probs
...
@@ -390,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
     stride_merging_probs_grad_token,
     stride_merging_probs_grad_expert,
     # metas
-    FP8_DTYPE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    if FP8_DTYPE == "e5m2":
-        data_type = tl.float8e5
-        pytorch_tensor_dtype = tl.uint8
-    elif FP8_DTYPE == "e4m3":
-        data_type = tl.float8e4nv
-        pytorch_tensor_dtype = tl.uint8
-    else:
-        data_type = fwd_output_grad_ptr.dtype.element_ty
-        assert FP8_DTYPE is None
+    data_type = fwd_output_grad_ptr.dtype.element_ty
     compute_type = tl.float32
     pid = tl.program_id(0)
...
@@ -418,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
                     + current_offset * stride_fwd_output_grad_hidden
                 )
                 inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
-                if FP8_DTYPE is not None:
-                    inp = inp.to(data_type, bitcast=True)
                 inp = inp.to(compute_type)
                 merging_prob_off = (
                     pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
...
@@ -427,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
                 merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
                 output = inp * merging_prob
                 output = output.to(data_type)
-                if FP8_DTYPE is not None:
-                    output = output.to(pytorch_tensor_dtype, bitcast=True)
                 output_off = (
                     dst_row * stride_fwd_input_grad_token
                     + current_offset * stride_fwd_input_grad_hidden
...
@@ -439,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
                     dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
                 )
                 fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
-                if FP8_DTYPE is not None:
-                    fwd_input = fwd_input.to(data_type, bitcast=True)
                 prob_grad_accum += fwd_input.to(compute_type) * inp
                 current_start += BLOCK_SIZE
             probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
...
@@ -481,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
     num_experts: int,
     num_out_tokens: int,
     hidden_size: int,
-    fp8_dtype: TE_DType,
 ):
     # pylint: disable=missing-function-docstring
-    if fp8_dtype == TE_DType.kFloat8E5M2:
-        fp8_dtype = "e5m2"
-    elif fp8_dtype == TE_DType.kFloat8E4M3:
-        fp8_dtype = "e4m3"
-    else:
-        fp8_dtype = None
     act_grad = torch.empty(
         (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
    )
...
@@ -517,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
         merging_probs.stride(1),
         merging_probs_grad.stride(0),
         merging_probs_grad.stride(1),
-        fp8_dtype,
     )
     return act_grad, merging_probs_grad
...
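With this change permute_with_mask_map no longer special-cases FP8 dtypes inside the Triton kernels; instead an optional per-token scale tensor is permuted alongside the activations. A hedged sketch of the updated call follows (not part of the commit); the shapes are inferred from how the kernel indexes its inputs and may differ in detail.

# Sketch only: the updated return signature is (output, permuted_scale, permuted_probs).
output, permuted_scale, permuted_probs = permute_with_mask_map(
    inp,              # [num_tokens, hidden_size] activations
    row_id_map,       # [num_experts, num_tokens], -1 where a token is not routed to an expert
    probs,            # [num_tokens, num_experts] routing probabilities, or None
    scale,            # [num_tokens, scale_hidden_dim] per-token scales, or None
    num_tokens,
    num_experts,
    num_out_tokens,
    hidden_size,
    scale_hidden_dim,
)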
transformer_engine/pytorch/utils.py
View file @ ab3e5a92
...
@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple
 import torch
 
 import transformer_engine.pytorch.cpp_extensions as ext
+from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
 from .tensor.quantized_tensor import QuantizedTensor
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -354,6 +355,19 @@ def round_up_to_nearest_multiple(value, multiple):
     return ((value + multiple - 1) // multiple) * multiple
 
 
+def needs_quantized_gemm(obj, rowwise=True):
+    """Used to check if obj will need quantized gemm or normal gemm."""
+    if isinstance(obj, DebugQuantizedTensor):
+        return type(obj.get_tensor(not rowwise)) not in [
+            # pylint: disable=unidiomatic-typecheck
+            torch.Tensor,
+            torch.nn.Parameter,
+        ]
+    return type(obj) not in [
+        torch.Tensor,
+        torch.nn.Parameter,
+    ]  # pylint: disable=unidiomatic-typecheck
+
+
 @functools.lru_cache(maxsize=None)
 def _nvtx_enabled() -> bool:
     """Check if NVTX range profiling is enabled"""
...
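The new needs_quantized_gemm helper decides the GEMM path from the runtime type of its input. A small hedged sketch of the expected behavior (not part of the commit):

# Sketch only: plain tensors and Parameters take the normal GEMM path; quantized
# tensor subclasses (and debug wrappers holding them) take the quantized path.
import torch
from transformer_engine.pytorch.utils import needs_quantized_gemm

x = torch.randn(16, 32)
assert not needs_quantized_gemm(x)                      # plain torch.Tensor
assert not needs_quantized_gemm(torch.nn.Parameter(x))  # plain Parameter
# A Float8Tensor (or other QuantizedTensor subclass) would make this return True.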