OpenDAS / TransformerEngine / Commits

Commit c1a1c04e, authored Dec 27, 2025 by wenjh

    Merge nv_main(2.10) to main

    Signed-off-by: wenjh <wenjh@sugon.com>

Parents: e698a0a7, 66aed3ae
Changes: 208 files in this commit.
Showing 8 changed files with 102 additions and 967 deletions (+102 -967).
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py   +4   -32
transformer_engine/pytorch/tensor/utils.py                          +69  -40
transformer_engine/pytorch/transformer.py                           +8   -1
transformer_engine/pytorch/triton/__init__.py                       +1   -1
transformer_engine/pytorch/triton/cross_entropy.py                  +6   -246
transformer_engine/pytorch/triton/pad.py                            +2   -53
transformer_engine/pytorch/triton/permutation.py                    +11  -593
transformer_engine/pytorch/utils.py                                 +1   -1
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -13,13 +13,12 @@ import warnings

 import torch

-# import transformer_engine_torch as tex
+import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
-from ..quantized_tensor import QuantizedTensorStorage
-# from ...constants import TE_DType as torch_to_transformer_engine_dtype
-from ..quantized_tensor import Quantizer
+from ..quantized_tensor import QuantizedTensorStorage, Quantizer
+from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ...utils import _empty_tensor
@@ -46,34 +45,7 @@ class _FromNVFP4Func(torch.autograd.Function):

         # Dequantize row-wise data
         if tensor._rowwise_data is not None:
-            ### TODO(tmoon): Debug dequantize kernel and remove unfused impl
-            # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
-
-            # Tensor properties
-            shape = list(tensor._rowwise_data.size())
-            shape[-1] *= 2
-            device = tensor._rowwise_data.device
-
-            # Convert FP4E2M1 values to FP32
-            data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
-            data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
-            data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
-            data = data.to(torch.float32).contiguous()
-
-            # Convert FP8E4M3 block scales to FP32
-            block_scales = tensor._rowwise_scale_inv
-            block_scales = block_scales.reshape(-1, block_scales.size(-1))
-            block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
-            block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
-
-            # Convert amax to FP32 tensor scale
-            tensor_scale = tensor._amax_rowwise / (6.0 * 448.0)  # Scale by FP4E2M1 and FP8E4M3 max
-
-            # Apply scales
-            block_data = data.view(-1, 16)
-            block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
-
-            return data.to(dtype)
+            return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
         if tensor._columnwise_data is not None:
             raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
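For reference, the unfused path being deleted above documents the NVFP4 layout: two E2M1 codes per byte (low nibble first), one FP8E4M3 scale per 16-element block, and a global scale derived from the per-tensor amax. A minimal standalone sketch of that math, assuming the standard E2M1 value table (the helper `_fp4_e2m1_vals` referenced by the removed code is replaced here by an explicit lookup, and the function name is illustrative):

import torch

# Standard FP4 E2M1 code values: sign bit plus {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
_FP4_E2M1_VALS = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequantize_nvfp4_rowwise(packed: torch.Tensor, block_scales: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Sketch of the removed unfused dequantize.

    packed: uint8 tensor holding two FP4 codes per byte (low nibble first);
            total unpacked element count must be a multiple of 16.
    block_scales: one float32 scale per 16-element block (already converted from FP8E4M3).
    amax: scalar per-tensor abs-max.
    """
    codes = packed.to(torch.int32)
    # Unpack low nibble then high nibble, doubling the last dimension.
    vals = torch.stack((codes & 0x0F, codes >> 4), dim=-1).flatten(start_dim=-2)
    data = _FP4_E2M1_VALS[vals]
    # Global scale: amax divided by the product of the FP4E2M1 max (6.0)
    # and the FP8E4M3 max (448.0), matching the removed code.
    tensor_scale = amax / (6.0 * 448.0)
    blocks = data.reshape(-1, 16) * (tensor_scale * block_scales.reshape(-1, 1))
    return blocks.reshape(data.shape)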
transformer_engine/pytorch/tensor/utils.py
@@ -4,18 +4,18 @@

 """Helper functions for using fp8 tensors as weights"""
 import os
-from typing import Optional, Union
+from typing import Optional, Union, List

 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 import transformer_engine_torch as tex
 from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv

 from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
 from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
 from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
 from ..optimizers.multi_tensor_apply import multi_tensor_applier
+from ..utils import is_non_tn_fp8_gemm_supported


 def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
@@ -48,7 +48,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):

 def cast_master_weights_to_fp8(
-    model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None
+    model_weights,
+    master_weights,
+    start_offsets,
+    group,
+    fsdp_shard_model_weights=None,
+    manual_post_all_gather_processing=False,
 ):
     r"""Helper function to cast master weights to FP8 primary weights.
@@ -69,6 +74,11 @@ def cast_master_weights_to_fp8(
     fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
                                not sharded. Otherwise, it means that the model weights are sharded and we get
                                target model weights data storage using the FSDP shard model weights.
+    manual_post_all_gather_processing: bool, default = `False`.
+                               If False, post processing will be automatically triggered during next forward.
+                               If True, the timing of calling post_all_gather_processing is left to the user.
+                               Note that users must call `post_all_gather_processing` if it's set to True,
+                               otherwise the weights won't be updated correctly.
     """
@@ -129,21 +139,18 @@ def cast_master_weights_to_fp8(
                 f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet"
             )

+    extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing]
     if len(delayed_scaling_params) > 0:
-        _cast_master_weights_to_fp8_delayed_scaling(
-            delayed_scaling_params, group, use_fsdp_shard_model_weights
-        )
+        _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, *extra_args)
     if len(current_scaling_params) > 0:
-        _cast_master_weights_to_fp8_current_scaling(
-            current_scaling_params, group, use_fsdp_shard_model_weights
-        )
+        _cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args)
     if len(blockwise_scaling_params) > 0:
-        _cast_master_weights_to_fp8_blockwise_scaling(
-            blockwise_scaling_params, group, use_fsdp_shard_model_weights
-        )
+        _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args)


-def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
+def _cast_master_weights_to_fp8_delayed_scaling(
+    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
+):
     r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.

     Parameters
@@ -160,11 +167,12 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
     amaxes, scales, scale_invs = [], [], []

     for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
-        # Reset transpose cache for all model weights.
-        # We cannot create transpose cache here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # currently.
-        model_weight._reset_caches()
+        if not manual_post_all_gather_processing:
+            # Reset transpose cache for all model weights.
+            # We cannot create transpose cache here because users (like megatron) may want to
+            # overlap the all-gather of model weights and forward process, so the model weight is
+            # not updated currently.
+            model_weight._reset_caches()

         quantizer = model_weight._get_quantizer()
@@ -225,7 +233,9 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
     )


-def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
+def _cast_master_weights_to_fp8_current_scaling(
+    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
+):
     r"""Helper function to cast master weights to FP8 primary weights for current scaling.

     Parameters
@@ -305,11 +315,12 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
     for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
         params, scales
     ):
-        # Reset transpose cache for all model weights.
-        # We cannot create transpose cache here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # currently.
-        model_weight._reset_caches()
+        if not manual_post_all_gather_processing:
+            # Reset transpose cache for all model weights.
+            # We cannot create transpose cache here because users (like megatron) may want to
+            # overlap the all-gather of model weights and forward process, so the model weight is
+            # not updated currently.
+            model_weight._reset_caches()

         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
@@ -336,7 +347,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo

 def _cast_master_weights_to_fp8_blockwise_scaling(
-    params, group, use_fsdp_shard_model_weights=False
+    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
 ):
     r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.
@@ -437,11 +448,12 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
     for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
         params, scales
     ):
-        # Clear columnwise data for all model weights.
-        # We cannot create columnwise data here because users (like megatron) may want to overlap
-        # the all-gather of model weights and forward process, so the model weight is not updated
-        # at this moment.
-        model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
+        if not manual_post_all_gather_processing:
+            # Clear columnwise data for all model weights.
+            # We cannot create columnwise data here because users (like megatron) may want to
+            # overlap the all-gather of model weights and forward process, so the model weight is
+            # not updated at this moment.
+            model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)

         # If master weight is None, it means that the master weight of the current model weight
         # is in other DP ranks.
@@ -459,18 +471,35 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
     )


-def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
-    """Check if an environment or object is using experimental Kitchen middleware.
+def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]):
+    """
+    Post-processing after all-gather for weights in distributed optimizer.
+    - Float8Tensor: may need to create a transposed view to match backend GEMM.
+    - Float8BlockwiseQTensor: create column-wise storage.
+    - Plain pytorch tensor: noop.
+    """
+    if not isinstance(model_weights, list):
+        model_weights = [model_weights]
+    for model_weight in model_weights:
+        if isinstance(model_weight, Float8Tensor):
+            # Delayed scaling and per-tensor current scaling: if backend does not support
+            # non-transposed FP8 GEMM, pre-create the transpose.
+            if not is_non_tn_fp8_gemm_supported():
+                model_weight._create_transpose()
+        elif isinstance(model_weight, Float8BlockwiseQTensor):
+            # Blockwise scaling: create column-wise storage.
+            model_weight._create_columnwise()
+        elif isinstance(model_weight, QuantizedTensor):
+            raise ValueError(f"post_processing for {type(model_weight)} is not supported")
+
+
+def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
+    """Check if an object is custom.

+    Returns False if x is a torch.Tensor.
     """
-    # Detect if the environment is experimental
-    if x is None:
-        return int(os.getenv("QAT_PARAMS", "0")) > 0
-    # Detect if the object is experimental
-    if isinstance(x, torch.Tensor):
+    if x is None or isinstance(x, torch.Tensor):
         return False
     if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
         raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
-    return hasattr(x, "experimental") and x.experimental
+    return hasattr(x, "custom") and x.custom
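The new flow splits the cast from the cache rebuild so the all-gather can overlap with forward. A minimal sketch of the intended call pattern, assuming the caller owns the surrounding tensors and collectives (the elided all-gather step and the variable names here are placeholders, not part of this API):

from transformer_engine.pytorch.tensor.utils import (
    cast_master_weights_to_fp8,
    post_all_gather_processing,
)

def update_fp8_weights_with_overlap(model_weights, master_weights, start_offsets, group):
    # Cast the local master-weight shards into the FP8 model weights, but defer
    # cache resets / rebuilds via the new manual flag.
    cast_master_weights_to_fp8(
        model_weights,
        master_weights,
        start_offsets,
        group,
        manual_post_all_gather_processing=True,
    )
    # ... all-gather the updated weight shards here, possibly overlapped with
    # the forward pass (caller-specific, elided) ...
    # Mandatory once manual_post_all_gather_processing=True: build transposes
    # (Float8Tensor) or column-wise storage (Float8BlockwiseQTensor).
    post_all_gather_processing(model_weights)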
transformer_engine/pytorch/transformer.py
@@ -176,7 +176,12 @@ class TransformerLayer(torch.nn.Module):
     activation : str, default = 'gelu'
                  Type of activation used in MLP block.
-                 Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
-                 'silu', and 'swiglu'.
+                 Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
+                 'silu', 'swiglu', and 'clamped_swiglu'.
+    activation_params : Optional[dict], default = `None`
+                 Additional parameters for the activation function.
+                 At the moment, only used for 'clamped_swiglu' activation which
+                 supports 'limit' and 'alpha' parameters. You can set these as
+                 `activation_params={'limit': 7.0, 'alpha': 1.702}`.
     device : Union[torch.device, str], default = "cuda"
                  The device on which the parameters of the model will be allocated. It is the user's
                  responsibility to ensure all parameters are moved to the GPU before running the
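For illustration, constructing a layer with the new activation keyword (the sizes below are hypothetical; the keyword values mirror the docstring above):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    activation="clamped_swiglu",
    activation_params={"limit": 7.0, "alpha": 1.702},
)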
@@ -310,6 +315,7 @@ class TransformerLayer(torch.nn.Module):
         ub_bulk_wgrad: bool = True,
         bias: bool = True,
         activation: str = "gelu",
+        activation_params: Optional[dict] = None,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
@@ -475,6 +481,7 @@ class TransformerLayer(torch.nn.Module):
             ub_overlap_rs=ub_overlap_rs,
             ub_overlap_ag=ub_overlap_ag,
             activation=activation,
+            activation_params=activation_params,
             normalization=normalization,
             device=device,
             name=name + ".layernorm_mlp" if name is not None else None,
transformer_engine/pytorch/triton/__init__.py

@@ -2,4 +2,4 @@
 #
 # See LICENSE for license information.

-"""Kernels written with OpenAI Triton."""
+"""PyTorch wrappers for Triton kernels."""
transformer_engine/pytorch/triton/cross_entropy.py
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.

-"""Efficient Cross Entropy kernels written with OpenAI Triton."""
+"""PyTorch wrapper functions for Cross Entropy Triton kernels."""

 from typing import Union
 from functools import reduce
@@ -13,257 +13,17 @@ import torch
 import torch.distributed as dist
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def online_softmax_kernel(
-    X_ptr,
-    X_stride,
-    Y_ptr,
-    Y_stride,
-    m_d_X_y_ptr,
-    m_d_X_y_stride,
-    rank,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    This kernel computes the m/d components on this TP rank for the online softmax.
-
-    Parameters:
-    X_ptr: Pointer to input tensor.
-    X_stride (int): The stride of the input tensor.
-    Y_ptr: Pointer to target tensor.
-    Y_stride (int): The stride of the target tensor.
-    m_d_X_y_ptr: Pointer to m/d/X_y tensor.
-    m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
-    rank (int): The rank of this device in the TP group.
-    n_cols (int): The number of columns in the input tensor.
-    BLOCK_SIZE (int): The block size for Triton operations.
-    """
-    program_id = tl.program_id(0).to(tl.int64)
-
-    # locate the start index
-    X_ptr += program_id * X_stride
-
-    # Load Y_ptr
-    Y_ptr += program_id * Y_stride
-    y = tl.load(Y_ptr)
-
-    vocab_start_idx = rank * n_cols
-    vocab_end_idx = (rank + 1) * n_cols
-    if y >= vocab_start_idx:
-        if y < vocab_end_idx:
-            X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32)
-        else:
-            X_y = float("-inf")
-    else:
-        X_y = float("-inf")
-
-    m_d_X_y_ptr += program_id * m_d_X_y_stride * 3
-
-    # 3. [Online softmax] first pass: find max + sum
-    m = float("-inf")  # m is the max value. use the notation from the paper
-    d = 0.0  # d is the sum. use the notation from the paper
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to(tl.float32)
-        block_max = tl.max(X_block)
-        m_new = tl.maximum(m, block_max)
-        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
-        m = m_new
-
-    tl.store(m_d_X_y_ptr, m)
-    tl.store(m_d_X_y_ptr + m_d_X_y_stride, d)
-    tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y)
-
-
-@triton.jit
-def cross_entropy_kernel(
-    X_ptr,
-    X_stride,
-    Y_ptr,
-    Y_stride,
-    loss_ptr,
-    loss_stride,
-    m_d_X_y_ptr,
-    m_d_X_y_stride,
-    rank,
-    world_size,
-    ignore_idx,
-    n_cols,
-    n_non_ignore,
-    reduce_loss: tl.constexpr,
-    label_smoothing: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    This kernel computes both cross entropy loss and the gradient of the input.
-
-    Parameters:
-    X_ptr: Pointer to input tensor.
-    X_stride (int): The stride of the input tensor.
-    Y_ptr: Pointer to target tensor.
-    Y_stride (int): The stride of the target tensor.
-    loss_ptr: Pointer to tensor to store the loss.
-    loss_stride (int): The stride of the loss tensor.
-    m_d_X_y_ptr: Pointer to m/d/X_y tensor.
-    m_d_X_y_stride: The stride of m/d/X_y tensor.
-    rank (int): The rank of this device in the TP group.
-    world_size (int): The size of world involved in this distributed loss calculation.
-    ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
-    n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
-    label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
-    BLOCK_SIZE (int): The block size for Triton operations.
-    """
-    program_id = tl.program_id(0).to(tl.int64)
-
-    # locate the start index
-    X_ptr += program_id * X_stride
-
-    # Load Y_ptr
-    Y_ptr += program_id * Y_stride
-    y = tl.load(Y_ptr)
-
-    if y == ignore_idx:
-        # set all X_ptr as 0
-        for i in range(0, n_cols, BLOCK_SIZE):
-            X_offsets = i + tl.arange(0, BLOCK_SIZE)
-            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
-        return
-
-    loss_ptr += program_id * loss_stride
-    m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride
-
-    # Need to reduce the m/d/X_y values from other TP ranks
-    m = tl.load(m_d_X_y_ptr)
-    d = tl.load(m_d_X_y_ptr + m_d_X_y_stride)
-    ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
-    for i in range(1, world_size):
-        offset = i * 3 * n_non_ignore * m_d_X_y_stride
-        access_ptr = m_d_X_y_ptr + offset
-        m_new = tl.load(access_ptr)
-        d_new = tl.load(access_ptr + m_d_X_y_stride)
-        X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride))
-        d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new))
-        m = tl.maximum(m, m_new)
-        ori_X_y = tl.maximum(ori_X_y, X_y_new)
-
-    # Label smoothing is a general case of normal cross entropy
-    scaled_x_sum = 0.0
-    eps = label_smoothing / (n_cols * world_size)
-
-    # 4. [Online softmax] second pass: calculate the gradients
-    # dx_y = (softmax(x_y) - 1) / N
-    # dx_i = softmax(x_i) / N, i != y
-    # N is the number of non ignored elements in the batch
-    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
-    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
-    #      = dx_i - (1 - label_smoothing) / N
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
-        grad_dtype = X_block.dtype
-        X_block = X_block.to(tl.float32)
-        if label_smoothing > 0:
-            # scale X beforehand to avoid overflow
-            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
-        # Scale gradients based on reduction mode
-        # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
-        # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
-        if reduce_loss:
-            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
-        else:
-            X_block = tl.exp(X_block - m) / d - eps
-        tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
-
-    # We need tl.debug_barrier() to ensure the new result of X_ptr is written
-    tl.debug_barrier()
-
-    # 5. Calculate the loss
-    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
-    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
-    loss = -(ori_X_y - m - tl.log(d))
-
-    # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
-    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
-    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
-    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
-    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
-    if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
-        loss = loss * (1 - label_smoothing) + smooth_loss
-
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    vocab_start_idx = rank * n_cols
-    vocab_end_idx = (rank + 1) * n_cols
-    if y >= vocab_start_idx:
-        if y < vocab_end_idx:
-            X_y = tl.load(X_ptr + y - vocab_start_idx)
-            # Apply the same conditional scaling logic for the target token
-            if reduce_loss:
-                X_y += -(1 - label_smoothing) / (n_non_ignore)
-            else:
-                X_y += -(1 - label_smoothing)
-            tl.store(X_ptr + y - vocab_start_idx, X_y)
-
-    tl.store(loss_ptr, loss)
+
+from transformer_engine.common.triton.cross_entropy import (
+    online_softmax_kernel,
+    cross_entropy_kernel,
+    element_mul_kernel,
+)

 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
 MAX_FUSED_SIZE = 65536 // 2

-
-@triton.jit
-def element_mul_kernel(
-    X_ptr,
-    X_stride,
-    grad_output_ptr,
-    grad_output_stride,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
-    The multiplication is performed in-place on the tensor pointed by X_ptr.
-
-    Parameters:
-    X_ptr: Pointer to the input tensor.
-    X_stride (int): The stride of the input tensor.
-    grad_output_ptr: Pointer to the gradient output value.
-    n_cols (int): The number of columns in the input tensor.
-    BLOCK_SIZE (int): The block size for Triton operations.
-    """
-    # Get the program ID and convert it to int64 to avoid overflow
-    program_id = tl.program_id(0).to(tl.int64)
-
-    # Locate the start index
-    X_ptr += program_id * X_stride
-
-    # Load the gradient output value
-    grad_output_ptr += program_id * grad_output_stride
-    grad_output = tl.load(grad_output_ptr)
-
-    # Perform the element-wise multiplication
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
-        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
-

 def cross_entropy_forward(
     _input: torch.Tensor,
     target: torch.Tensor,
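The removed kernels implement the standard online-softmax recurrence: each vocabulary shard keeps a running (max, sum-of-exp) pair, and pairs from different TP ranks merge associatively. As an illustration (not the kernel API), a small PyTorch sketch of the per-shard statistics and the merge rule:

import torch

def merge_online_softmax_stats(m_a, d_a, m_b, d_b):
    """Merge two (max, sum-of-exp) pairs computed over disjoint vocabulary shards,
    following the same rule as the removed cross_entropy_kernel."""
    m = torch.maximum(m_a, m_b)
    d = d_a * torch.exp(m_a - m) + d_b * torch.exp(m_b - m)
    return m, d

# Stats from two shards of the logits agree with stats over the full row.
x = torch.randn(8, 1024)
a, b = x[:, :512], x[:, 512:]
m_a = a.max(dim=-1).values
d_a = torch.exp(a - a.max(dim=-1, keepdim=True).values).sum(-1)
m_b = b.max(dim=-1).values
d_b = torch.exp(b - b.max(dim=-1, keepdim=True).values).sum(-1)
m, d = merge_online_softmax_stats(m_a, d_a, m_b, d_b)
log_sum_exp = m + torch.log(d)  # loss uses -(x_y - m - log(d))
assert torch.allclose(log_sum_exp, torch.logsumexp(x, dim=-1), atol=1e-5)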
transformer_engine/pytorch/triton/pad.py
@@ -2,63 +2,12 @@
 #
 # See LICENSE for license information.

-"""NVFP4 padding kernels
-
-TODO(ksivamani): Documentation
-"""
+"""PyTorch wrapper functions for padding Triton kernels."""

 import torch
-import triton
-import triton.language as tl
-
-
-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
-        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
-    ],
-    key=["out_dim0", "out_dim1"],
-)
-@triton.jit
-def zero_pad_kernel(
-    inp_ptr,
-    out_ptr,
-    in_dim0: tl.constexpr,
-    in_dim1: tl.constexpr,
-    out_dim0: tl.constexpr,
-    out_dim1: tl.constexpr,
-    in_s0,
-    in_s1,
-    out_s0,
-    out_s1,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-):
-    """Pads a tensor assuming it's a columnwise scaling inverse."""
-    # tile over OUTPUT coordinates
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # output rows
-    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # output cols
-    om = offs_m[:, None]
-    on = offs_n[None, :]
-    # edge masking for output
-    out_mask = (om < out_dim0) & (on < out_dim1)
-    # valid input region is simply top-left (no offsets)
-    in_mask = (om < in_dim0) & (on < in_dim1)
-    # load valid input, else zero (masked load touches memory only where True)
-    x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0)
-    # store to output (only within bounds of the output tile)
-    tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask)
+
+from transformer_engine.common.triton.pad import zero_pad_kernel


 def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
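Functionally, the wrapper still zero-pads a 2-D scale-inverse into a larger output with the input occupying the top-left corner. A hedged torch-level equivalent of what zero_pad_kernel computes (function name and signature are illustrative):

import torch
import torch.nn.functional as F

def zero_pad_2d(inp: torch.Tensor, out_dim0: int, out_dim1: int) -> torch.Tensor:
    in_dim0, in_dim1 = inp.shape
    # F.pad takes (left, right, top, bottom) padding for the last two dims;
    # only the right and bottom edges grow, matching the kernel's masking.
    return F.pad(inp, (0, out_dim1 - in_dim1, 0, out_dim0 - in_dim0))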
transformer_engine/pytorch/triton/permutation.py
@@ -2,192 +2,23 @@
 #
 # See LICENSE for license information.

-"""Permutation kernels written with OpenAI Triton."""
+"""PyTorch wrapper functions for Permutation Triton kernels."""

 from typing import Union

 import torch
 import triton
 import triton.language as tl
-from triton.language import core
-from triton.language.standard import _log2
-
-
-# The following three argsort related kernels are adapted from
-# the issue https://github.com/triton-lang/triton/issues/3698
-@triton.jit
-def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
-    n_outer: tl.constexpr = x.numel >> n_dims
-    shape: tl.constexpr = [n_outer * (2 ** i), 2, 2 ** (n_dims - i - 1)]
-    y = tl.reshape(x, shape)
-    z = tl.reshape(indices, shape)
-    mask = tl.arange(0, 2)[None, :, None]
-    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(x.dtype)
-    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(x.dtype)
-    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
-    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
-    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    il_value = l_value.to(idtype, bitcast=True)
-    ir_value = r_value.to(idtype, bitcast=True)
-    ix = x.to(idtype, bitcast=True)
-    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
-    ret = ix ^ flag1
-    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
-    ind = indices ^ flag2
-    return ret.to(x.dtype, bitcast=True), ind
-
-
-@triton.jit
-def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
-    n_outer: tl.constexpr = x.numel >> n_dims
-    tl.static_assert(stage <= n_dims)
-    """
-    order_type 0 == ascending
-    order_type 1 == descending
-    order_type 2 == alternating
-    """
-    if order == 2:
-        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2 ** stage]
-        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
-    else:
-        flip = tl.full(x.shape, value=order, dtype=tl.int32)
-    for i in tl.static_range(stage):
-        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
-    return x, indices
-
-
-@triton.jit
-def _argsort(x, indices, n_dims: tl.constexpr):
-    for i in tl.static_range(1, n_dims + 1):
-        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
-    return x, indices
-
-
-@triton.jit
-def _row_id_map_pass_1_kernel(
-    # pointers
-    routing_map_ptr,
-    row_id_map_ptr,
-    workspace_ptr,
-    # sizes
-    num_tokens,
-    # strides
-    stride_routing_map_token,
-    stride_routing_map_expert,
-    stride_row_id_map_token,
-    stride_row_id_map_expert,
-    # metas
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    expert_token_mask = tl.load(
-        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
-        mask=(offset < num_tokens),
-        other=0,
-    ).to(tl.int32)
-    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
-    tl.store(
-        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
-        row_id_within_token_block,
-        mask=offset < num_tokens,
-    )
-    n_tokens_per_block = tl.sum(expert_token_mask)
-    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
-
-
-@triton.jit
-def _row_id_map_pass_2_kernel(
-    # pointers
-    row_id_map_ptr,
-    workspace_ptr,
-    # sizes
-    num_tokens,
-    # strides
-    stride_row_id_map_token,
-    stride_row_id_map_expert,
-    # metas
-    WORKSPACE_LOAD_WIDTH: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
-    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    row_id_within_token_block = tl.load(
-        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
-        mask=(offset < num_tokens),
-        other=0,
-    )
-    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
-    n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
-    row_id = tl.where(
-        row_id_within_token_block == 0,
-        -1,
-        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
-    )
-    tl.store(
-        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
-        row_id,
-        mask=(offset < num_tokens),
-    )
-
-
-@triton.jit
-def _row_id_map_pass_3_kernel(
-    # pointers
-    row_id_map_ptr,
-    # sizes
-    num_experts: tl.constexpr,
-    # strides
-    stride_row_id_map_token,
-    stride_row_id_map_expert,
-    # metas
-    LOAD_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    n_dims: tl.constexpr = _log2(LOAD_SIZE)
-    off = tl.arange(0, LOAD_SIZE)
-    row_id_map = tl.load(
-        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
-        mask=off < num_experts,
-        other=-1,
-    )
-    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
-    indices = off
-    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
-    tl.store(
-        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
-        sorted_map,
-        mask=off < n_routed,
-    )
-    tl.store(
-        row_id_map_ptr + pid * stride_row_id_map_token + (num_experts + off) * stride_row_id_map_expert,
-        indices,
-        mask=off < n_routed,
-    )
-    tl.store(
-        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
-        n_routed,
-    )
+
+from transformer_engine.common.triton.permutation import (
+    _row_id_map_pass_1_kernel,
+    _row_id_map_pass_2_kernel,
+    _row_id_map_pass_3_kernel,
+    _permute_kernel,
+    _unpermute_kernel,
+    _unpermute_bwd_with_merging_probs_kernel,
+    _make_chunk_sort_map_kernel,
+    _sort_chunks_by_map_kernel,
+)


 def make_row_id_map(
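The three removed row-id-map passes compute, per routed (expert, token) pair, a destination row in the expert-major permuted tensor: pass 1 counts routed tokens per block, pass 2 adds the offsets of all preceding blocks, and pass 3 compacts and sorts each token's entries. A dense reference sketch of passes 1 and 2 under the assumption that routing_map has shape [num_experts, num_tokens] (names and layout here are illustrative):

import torch

def make_row_id_map_reference(routing_map: torch.Tensor) -> torch.Tensor:
    """routing_map[e, t] is True iff token t is routed to expert e.
    Returns the destination row for each (expert, token) pair, or -1 if not routed."""
    flat = routing_map.to(torch.int64).flatten()  # expert-major order, as in the kernels
    dst = torch.cumsum(flat, dim=0) - 1           # running index over routed slots
    row_id_map = torch.where(flat.bool(), dst, torch.full_like(dst, -1))
    return row_id_map.reshape(routing_map.shape)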
@@ -287,103 +118,6 @@ def make_row_id_map(
     return row_id_map


-@triton.jit
-def _permute_kernel(
-    # pointers
-    input_ptr,
-    output_ptr,
-    row_id_map_ptr,
-    probs_ptr,
-    scale_ptr,
-    permuted_probs_ptr,
-    permuted_scale_ptr,
-    # sizes
-    num_experts: tl.constexpr,
-    hidden_size: tl.constexpr,
-    scale_hidden_dim,
-    # strides
-    stride_row_id_map_token,
-    stride_row_id_map_expert,
-    stride_input_token,
-    stride_input_hidden,
-    stride_output_token,
-    stride_output_hidden,
-    stride_probs_token,
-    stride_probs_expert,
-    stride_scale_token,
-    stride_scale_hidden,
-    stride_permuted_probs_token,
-    stride_permuted_scale_token,
-    stride_permuted_scale_hidden,
-    # metas
-    PERMUTE_PROBS: tl.constexpr,
-    PERMUTE_SCALE: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid_t = tl.program_id(0)
-    pid_h = tl.program_id(1)
-    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = cur_off < hidden_size
-    src_row = pid_t.to(tl.int64)
-    input_off = src_row * stride_input_token + cur_off * stride_input_hidden
-    inp = tl.load(input_ptr + input_off, mask=mask)
-    if PERMUTE_SCALE:
-        mask_scale = cur_off < scale_hidden_dim
-        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
-        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
-    n_routed = tl.load(
-        row_id_map_ptr + pid_t * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
-    )
-    for idx in tl.range(n_routed):
-        dst_row = tl.load(
-            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
-        ).to(tl.int64)
-        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
-        if PERMUTE_SCALE:
-            permuted_scale_off = (
-                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
-            )
-            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
-        if PERMUTE_PROBS:
-            expert_idx = tl.load(
-                row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert
-            )
-            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
-            prob = tl.load(probs_ptr + prob_off)
-            if pid_h == 0:
-                permuted_prob_off = dst_row * stride_permuted_probs_token
-                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
-            if prob == 0.0:
-                # for routing_map padding
-                # dst_row != -1 and prob == 0.0 means that this slot is padded
-                tl.store(output_ptr + output_off, 0.0, mask=mask)
-            else:
-                tl.store(output_ptr + output_off, inp, mask=mask)
-        else:
-            tl.store(output_ptr + output_off, inp, mask=mask)
-
-
-try:
-    _permute_kernel = triton.autotune(
-        configs=[
-            triton.Config({"BLOCK_SIZE": 64}),
-            triton.Config({"BLOCK_SIZE": 128}),
-            triton.Config({"BLOCK_SIZE": 256}),
-            triton.Config({"BLOCK_SIZE": 512}),
-            triton.Config({"BLOCK_SIZE": 1024}),
-            triton.Config({"BLOCK_SIZE": 2048}),
-            triton.Config({"BLOCK_SIZE": 4096}),
-        ],
-        key=["hidden_size"],
-    )(_permute_kernel)
-except RuntimeError:
-    pass
-
-
 def permute_with_mask_map(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
@@ -463,116 +197,6 @@ def permute_with_mask_map(
     return output, permuted_scale, permuted_probs


-@triton.jit
-def _unpermute_kernel(
-    # pointers
-    input_ptr,
-    output_ptr,
-    row_id_map_ptr,
-    merging_probs_ptr,
-    permuted_probs_ptr,
-    unpermuted_probs_ptr,
-    # sizes
-    num_experts: tl.constexpr,
-    hidden_size: tl.constexpr,
-    # strides
-    stride_row_id_map_token,
-    stride_row_id_map_expert,
-    stride_input_token,
-    stride_input_hidden,
-    stride_output_token,
-    stride_output_hidden,
-    stride_merging_probs_token,
-    stride_merging_probs_expert,
-    stride_permuted_probs_token,
-    stride_unpermuted_probs_token,
-    stride_unpermuted_probs_expert,
-    # metas
-    PROBS_LOAD_WIDTH: tl.constexpr,
-    WITH_MERGING_PROBS: tl.constexpr,
-    PERMUTE_PROBS: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    data_type = input_ptr.dtype.element_ty
-    compute_type = tl.float32
-    pid_t = tl.program_id(0)
-    pid_h = tl.program_id(1)
-    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = current_offset < hidden_size
-    if PERMUTE_PROBS:
-        # write 0.0 to probs_grad that are not routed
-        if pid_h == 0:
-            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
-            unpermuted_prob_off = (
-                pid_t * stride_unpermuted_probs_token + stride_unpermuted_probs_expert * map_load_off
-            )
-            tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts)
-    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
-    n_routed = tl.load(
-        row_id_map_ptr + pid_t * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
-    )
-    for idx in tl.range(n_routed):
-        src_row = tl.load(
-            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
-        ).to(tl.int64)
-        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
-        inp = tl.load(input_ptr + input_off, mask=mask)
-        inp = inp.to(compute_type)
-        if WITH_MERGING_PROBS:
-            expert_idx = tl.load(
-                row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert
-            )
-            merging_prob_off = (
-                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
-            )
-            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
-            inp *= merging_prob
-        accumulator += inp
-        if PERMUTE_PROBS:
-            if pid_h == 0:
-                expert_idx = tl.load(
-                    row_id_map_ptr + pid_t * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert
-                )
-                unpermuted_prob_off = (
-                    pid_t * stride_unpermuted_probs_token + expert_idx * stride_unpermuted_probs_expert
-                )
-                permuted_prob_off = src_row * stride_permuted_probs_token
-                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
-                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
-    accumulator = accumulator.to(data_type)
-    dst_row = pid_t.to(tl.int64)
-    output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
-    tl.store(output_ptr + output_off, accumulator, mask=mask)
-
-
-try:
-    _unpermute_kernel = triton.autotune(
-        configs=[
-            triton.Config({"BLOCK_SIZE": 64}),
-            triton.Config({"BLOCK_SIZE": 128}),
-            triton.Config({"BLOCK_SIZE": 256}),
-            triton.Config({"BLOCK_SIZE": 512}),
-            triton.Config({"BLOCK_SIZE": 1024}),
-            triton.Config({"BLOCK_SIZE": 2048}),
-            triton.Config({"BLOCK_SIZE": 4096}),
-        ],
-        key=["hidden_size"],
-    )(_unpermute_kernel)
-except RuntimeError:
-    pass
-
-
 def unpermute_with_mask_map(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
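Semantically, unpermute gathers each destination token's routed copies and sums them, optionally weighted by per-expert merging probabilities, mirroring the accumulator loop in the removed _unpermute_kernel. A dense reference sketch (names and the list-based map layout are illustrative, not the wrapper's API):

import torch

def unpermute_reference(permuted, src_rows_per_token, merging_probs=None):
    """permuted: [num_permuted_rows, hidden]; src_rows_per_token: list per output
    token of source row indices; merging_probs: optional per-token weight lists."""
    out = []
    for t, src_rows in enumerate(src_rows_per_token):
        acc = torch.zeros(permuted.shape[1], dtype=torch.float32)
        for k, r in enumerate(src_rows):
            x = permuted[r].to(torch.float32)  # accumulate in fp32, as the kernel does
            if merging_probs is not None:
                x = x * merging_probs[t][k]
            acc += x
        out.append(acc.to(permuted.dtype))
    return torch.stack(out)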
@@ -639,110 +263,6 @@ def unpermute_with_mask_map(
     return output, unpermuted_probs


-@triton.jit
-def _unpermute_bwd_with_merging_probs_kernel(
-    # pointers
-    fwd_output_grad_ptr,
-    fwd_input_grad_ptr,
-    fwd_input_ptr,
-    merging_probs_ptr,
-    merging_probs_grad_ptr,
-    row_id_map_ptr,
-    # sizes
-    num_experts: tl.constexpr,
-    hidden_size: tl.constexpr,
-    # strides
-    stride_row_id_map_token,
-    stride_row_id_map_expert,
-    stride_fwd_output_grad_token,
-    stride_fwd_output_grad_hidden,
-    stride_fwd_input_grad_token,
-    stride_fwd_input_grad_hidden,
-    stride_fwd_input_token,
-    stride_fwd_input_hidden,
-    stride_merging_probs_token,
-    stride_merging_probs_expert,
-    stride_merging_probs_grad_token,
-    stride_merging_probs_grad_expert,
-    # metas
-    PROBS_LOAD_WIDTH: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    data_type = fwd_output_grad_ptr.dtype.element_ty
-    compute_type = tl.float32
-    pid = tl.program_id(0)
-    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
-    token_probs_grad_off = (
-        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
-    )
-    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
-    n_routed = tl.load(
-        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
-    )
-    for idx in tl.range(n_routed):
-        dst_row = tl.load(
-            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
-        ).to(tl.int64)
-        expert_idx = tl.load(
-            row_id_map_ptr + pid * stride_row_id_map_token + (num_experts + idx) * stride_row_id_map_expert
-        )
-        prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
-        current_start = 0
-        while current_start < hidden_size:
-            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
-            mask = current_offset < hidden_size
-            src_row = pid.to(tl.int64)
-            input_off = (
-                src_row * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
-            )
-            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
-            inp = inp.to(compute_type)
-            merging_prob_off = (
-                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
-            )
-            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
-            output = inp * merging_prob
-            output = output.to(data_type)
-            output_off = (
-                dst_row * stride_fwd_input_grad_token + current_offset * stride_fwd_input_grad_hidden
-            )
-            tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
-            fwd_input_off = (
-                dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
-            )
-            fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
-            prob_grad_accum += fwd_input.to(compute_type) * inp
-            current_start += BLOCK_SIZE
-        probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
-        probs_grad_off = (
-            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
-        )
-        tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
-
-
-try:
-    _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
-        configs=[
-            triton.Config({"BLOCK_SIZE": 64}),
-            triton.Config({"BLOCK_SIZE": 128}),
-            triton.Config({"BLOCK_SIZE": 256}),
-            triton.Config({"BLOCK_SIZE": 512}),
-            triton.Config({"BLOCK_SIZE": 1024}),
-            triton.Config({"BLOCK_SIZE": 2048}),
-            triton.Config({"BLOCK_SIZE": 4096}),
-        ],
-        key=["hidden_size"],
-    )(_unpermute_bwd_with_merging_probs_kernel)
-except RuntimeError:
-    pass
-
-
 def unpermute_with_mask_map_bwd_with_merging_probs(
     fwd_output_grad: torch.Tensor,
     row_id_map: torch.Tensor,
@@ -808,47 +328,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
     return act_grad, merging_probs_grad


-@triton.jit
-def _make_chunk_sort_map_kernel(
-    # pointers
-    split_sizes_ptr,
-    sorted_indices_ptr,
-    dst_rows_ptr,
-    # sizes
-    num_splits: tl.constexpr,
-    # metas
-    IDX_LOAD_WIDTH: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
-    sorted_indices = tl.load(sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits)
-    # get chunk idx of the current token in the input tensor
-    input_split_sizes = tl.load(
-        split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
-    ).to(tl.int32)
-    input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
-    input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
-    input_chunk_idx = tl.sum(input_split_sizes_mask)
-    input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
-    in_chunk_offset = pid - input_split_sizes_presum
-    # get chunk idx of the current token in the output tensor
-    output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
-    output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)
-    # make row_id_map
-    output_split_sizes = tl.load(
-        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
-    ).to(tl.int32)
-    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
-    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
-    tl.store(dst_rows_ptr + pid, dst_row)
-
-
 def make_chunk_sort_map(
     split_sizes: torch.Tensor,
     sorted_indices: torch.Tensor,
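The removed chunk-sort-map kernel maps each input row to its row after the chunks are reordered: find the row's input chunk via a cumulative sum over split_sizes, then offset it by the total size of all chunks that precede its chunk in the new order. A dense reference sketch (function name and tensor dtypes are illustrative):

import torch

def make_chunk_sort_map_reference(split_sizes: torch.Tensor, sorted_indices: torch.Tensor) -> torch.Tensor:
    """split_sizes: 1-D long tensor of chunk lengths; sorted_indices: new chunk order."""
    num_tokens = int(split_sizes.sum())
    in_start = torch.cumsum(split_sizes, 0) - split_sizes        # input start row per chunk
    out_sizes = split_sizes[sorted_indices]
    out_start_by_pos = torch.cumsum(out_sizes, 0) - out_sizes    # output start row per output position
    # Output start row indexed by original chunk id.
    out_start = torch.empty_like(out_start_by_pos)
    out_start[sorted_indices] = out_start_by_pos
    dst_rows = torch.empty(num_tokens, dtype=torch.long)
    for c, size in enumerate(split_sizes.tolist()):
        s = int(in_start[c])
        dst_rows[s : s + size] = int(out_start[c]) + torch.arange(size)
    return dst_rows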
@@ -881,67 +360,6 @@ def make_chunk_sort_map(
     return row_id_map


-@triton.jit
-def _sort_chunks_by_map_kernel(
-    # pointers
-    input_ptr,
-    output_ptr,
-    row_id_map_ptr,
-    probs_ptr,
-    permuted_probs_ptr,
-    # sizes
-    hidden_size: tl.constexpr,
-    # strides
-    stride_input_token,
-    stride_input_hidden,
-    stride_output_token,
-    stride_output_hidden,
-    stride_probs_token,
-    stride_permuted_probs_token,
-    # metas
-    PERMUTE_PROBS: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    FORWARD: tl.constexpr,
-):
-    pid_t = tl.program_id(0)
-    pid_h = tl.program_id(1)
-    if FORWARD:
-        src_row = pid_t.to(tl.int64)
-        dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
-    else:
-        src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
-        dst_row = pid_t.to(tl.int64)
-    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = current_offset < hidden_size
-    input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
-    output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
-    inp = tl.load(input_ptr + input_offsets, mask=mask)
-    tl.store(output_ptr + output_offsets, inp, mask=mask)
-    if PERMUTE_PROBS:
-        if pid_h == 0:
-            prob_off = src_row * stride_probs_token
-            prob = tl.load(probs_ptr + prob_off)
-            permuted_prob_off = dst_row * stride_permuted_probs_token
-            tl.store(permuted_probs_ptr + permuted_prob_off, prob)
-
-
-try:
-    _sort_chunks_by_map_kernel = triton.autotune(
-        configs=[
-            triton.Config({"BLOCK_SIZE": 64}),
-            triton.Config({"BLOCK_SIZE": 128}),
-            triton.Config({"BLOCK_SIZE": 256}),
-            triton.Config({"BLOCK_SIZE": 512}),
-            triton.Config({"BLOCK_SIZE": 1024}),
-            triton.Config({"BLOCK_SIZE": 2048}),
-            triton.Config({"BLOCK_SIZE": 4096}),
-        ],
-        key=["hidden_size"],
-    )(_sort_chunks_by_map_kernel)
-except RuntimeError:
-    pass
-
-
 def sort_chunks_by_map(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
transformer_engine/pytorch/utils.py

@@ -12,7 +12,7 @@ import numpy as np
 import torch

 from . import torch_version
-from .tensor.quantized_tensor import Quantizer
+from .quantized_tensor import Quantizer
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]