OpenDAS / TransformerEngine / Commits / ab3e5a92

Commit ab3e5a92, authored May 09, 2025 by yuguo

    Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine

Parents: a8d19fd9, 04c730c0
Changes: 174
Showing 20 changed files with 573 additions and 351 deletions (+573, -351)
transformer_engine/jax/setup.py (+1, -1)
transformer_engine/jax/sharding.py (+28, -2)
transformer_engine/pytorch/attention.py (+23, -11)
transformer_engine/pytorch/constants.py (+6, -0)
transformer_engine/pytorch/cpp_extensions/gemm.py (+17, -53)
transformer_engine/pytorch/cpu_offload.py (+43, -30)
transformer_engine/pytorch/csrc/common.h (+32, -0)
transformer_engine/pytorch/csrc/extensions.h (+36, -19)
transformer_engine/pytorch/csrc/extensions/activation.cpp (+6, -1)
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp (+125, -161)
transformer_engine/pytorch/csrc/extensions/attention.cu (+13, -11)
transformer_engine/pytorch/csrc/extensions/cast.cpp (+12, -3)
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (+10, -10)
transformer_engine/pytorch/csrc/extensions/gemm.cpp (+15, -2)
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu (+4, -2)
transformer_engine/pytorch/csrc/extensions/normalization.cpp (+14, -6)
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp (+129, -0)
transformer_engine/pytorch/csrc/extensions/padding.cpp (+1, -1)
transformer_engine/pytorch/csrc/extensions/permutation.cu (+9, -30)
transformer_engine/pytorch/csrc/extensions/pybind.cpp (+49, -8)
transformer_engine/jax/setup.py

@@ -101,7 +101,7 @@ if __name__ == "__main__":
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
         install_requires=["jax", "flax>=0.7.1"],
-        tests_require=["numpy", "praxis"],
+        tests_require=["numpy"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
         shutil.rmtree(common_headers_dir)
transformer_engine/jax/sharding.py

@@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names):
     Convert logical axes to PartitionSpec
     """
     rules = get_sharding_map_logic_axis_to_mesh_axis()
-    mesh_axis_names = [rules[name] for name in logical_axis_names]
+    # mesh_axis_names = [rules[name] for name in logical_axis_names]
+    mesh_axis_names = []
+    for name in logical_axis_names:
+        axis_name = rules[name] if name in rules else None
+        mesh_axis_names.append(axis_name)
     pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
     return pspec

@@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t
     """
    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
     """
-    if logical_axis_names is None:
+    if not logical_axis_names:
         return x
     assert len(x.shape) == len(logical_axis_names)

@@ -315,3 +319,25 @@ class ShardingType(Enum):
     TP_ROW = (MajorShardingType.TP, "tp_row")
     DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
+
+
+def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+    """Get logical axes for non-contracting dimensions.
+
+    Args:
+        ndim: Number of dimensions in the tensor.
+        logical_axes: Tuple of logical axes for each dimension.
+        contracting_dims: Set of dimensions that are being contracted.
+
+    Returns:
+        Tuple of logical axes for non-contracting dimensions.
+    """
+    if not logical_axes:
+        logical_axes = (None,) * ndim
+    elif len(logical_axes) < ndim:
+        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
+    assert len(logical_axes) == ndim
+
+    non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
+    non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
+    return non_contracting_logical_axes
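For orientation, a minimal usage sketch of the new helper (the axis names are illustrative, not taken from this diff): missing trailing axes are padded with None before the contracted dimensions are dropped.

# Hypothetical usage of get_non_contracting_logical_axes; axis names illustrative.
from transformer_engine.jax.sharding import get_non_contracting_logical_axes

# A (batch, seq, hidden) tensor contracted over its last dimension:
axes = get_non_contracting_logical_axes(
    ndim=3,
    logical_axes=("batch", "seq", "hidden"),
    contracting_dims={2},
)
assert axes == ("batch", "seq")

# Shorter tuples are padded with None before filtering:
assert get_non_contracting_logical_axes(3, ("batch",), {0}) == (None, None)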
transformer_engine/pytorch/attention.py

@@ -20,6 +20,7 @@ import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
+from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.utils import (
     get_cudnn_version,
     nvtx_range_pop,

@@ -81,6 +82,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
+from .cpu_offload import mark_activation_offload

 # Setup Attention Logging

@@ -618,7 +620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(cp_group)
         send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type

@@ -1566,7 +1568,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(ctx.cp_group)
         send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
         q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
             restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
         )

@@ -4323,10 +4325,9 @@ class FlashAttention(torch.nn.Module):
         from .cpu_offload import CPUOffloadEnabled

         if CPUOffloadEnabled:
-            tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
-            for tensor in tensor_list:
-                if tensor is not None:
-                    tensor.activation_offloading = True
+            mark_activation_offload(
+                query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
+            )

         with self.attention_dropout_ctx():
             # | API | use cases

@@ -4728,12 +4729,9 @@ class FusedAttnFunc(torch.autograd.Function):
             else:
                 tensor_list = [q, k, v, out_save]
-            tensor_list.extend(aux_ctx_tensors)

             qkv_layout = "sbhd_sbhd_sbhd"
-            for tensor in tensor_list:
-                if tensor is not None:
-                    tensor.activation_offloading = True
+            mark_activation_offload(*tensor_list)
+            mark_activation_offload(*aux_ctx_tensors)

             ctx.is_input_fp8 = is_input_fp8
             ctx.is_output_fp8 = is_output_fp8

@@ -6482,6 +6480,8 @@ class MultiheadAttention(torch.nn.Module):
                  equal length. Please note that these formats do not reflect how
                  tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
                  For that, please use `get_qkv_layout` to gain the layout information.
+    name: str, default = `None`
+        name of the module, currently used for debugging purposes.

     Parallelism parameters
     ----------------------

@@ -6560,6 +6560,7 @@ class MultiheadAttention(torch.nn.Module):
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
         qkv_format: str = "sbhd",
+        name: str = None,
     ) -> None:
         super().__init__()

@@ -6611,6 +6612,8 @@ class MultiheadAttention(torch.nn.Module):
         self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
         self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

+        self.name = name
+
         common_gemm_kwargs = {
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
             "tp_group": tp_group,

@@ -6651,6 +6654,7 @@ class MultiheadAttention(torch.nn.Module):
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
+                name=name + ".layernorm_linear_qkv" if name is not None else None,
                 **common_gemm_kwargs,
             )
         else:

@@ -6662,6 +6666,7 @@ class MultiheadAttention(torch.nn.Module):
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 parameters_split=parameters_split,
+                name=name + ".linear_qkv" if name is not None else None,
                 **common_gemm_kwargs,
             )
         elif self.attention_type == "cross":

@@ -6683,6 +6688,7 @@ class MultiheadAttention(torch.nn.Module):
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
+                name=name + ".layernorm_linear_q" if name is not None else None,
                 **common_gemm_kwargs,
             )
         else:

@@ -6693,6 +6699,7 @@ class MultiheadAttention(torch.nn.Module):
                 bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
+                name=name + ".linear_q" if name is not None else None,
                 **common_gemm_kwargs,
             )
             self.key_value = Linear(

@@ -6703,6 +6710,7 @@ class MultiheadAttention(torch.nn.Module):
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 parameters_split=("key", "value") if not fuse_qkv_params else None,
+                name=name + ".linear_kv" if name is not None else None,
                 **common_gemm_kwargs,
             )

@@ -6732,6 +6740,7 @@ class MultiheadAttention(torch.nn.Module):
             ub_overlap_rs=ub_overlap_rs,
             ub_overlap_ag=ub_overlap_ag,
             ub_name="proj",
+            name=name + ".proj" if name is not None else None,
             **common_gemm_kwargs,
         )

@@ -6922,6 +6931,9 @@ class MultiheadAttention(torch.nn.Module):
             core_attention_bias_type in AttnBiasTypes
         ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

+        if TEDebugState.debug_enabled:
+            TransformerEngineBaseModule._validate_name(self)
+
         # =================================================
         # Pre-allocate memory for key-value cache for inference
         # =================================================
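A hedged sketch of the new name argument from the caller's side (sizes illustrative; only arguments visible in this diff are assumed): the string is propagated to the inner projection modules as derived names such as "mha0.proj".

# Hypothetical usage of the new MultiheadAttention name= argument.
from transformer_engine.pytorch import MultiheadAttention

mha = MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    name="mha0",  # forwarded as "mha0.layernorm_linear_qkv", "mha0.proj", ...
)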
transformer_engine/pytorch/constants.py

@@ -24,6 +24,12 @@ TE_DType = {
     torch.bfloat16: tex.DType.kBFloat16,
 }

+"""
+This is a map: int -> torch.dtype
+Used for resolving cuda extension types to torch.
+Has one to one mapping with enum in
+transformer_engine.h
+"""
 TE_DType_To_Torch = {
     tex.DType.kByte: torch.uint8,
     tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
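A quick round-trip sketch using both maps (the kBFloat16 reverse entry is assumed to be present alongside the entries shown in the diff):

# Sketch: torch dtype -> TE enum -> torch dtype round trip.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch

te_dtype = TE_DType[torch.bfloat16]                    # tex.DType.kBFloat16
assert TE_DType_To_Torch[te_dtype] is torch.bfloat16  # assumes a kBFloat16 entry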
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -9,11 +9,11 @@ import os
 import torch
 import transformer_engine_torch as tex
 from ..constants import TE_DType
-from ..utils import assert_dim_for_fp8_exec, get_sm_count
+from ..utils import get_sm_count
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
+from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ...debug.pytorch.debug_quantization import DebugQuantizer

 __all__ = [
     "general_gemm",

@@ -28,46 +28,6 @@ def _empty_tensor() -> torch.Tensor:
     return torch.Tensor().cuda()

-def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str):
-    """Swizzle gemm inputs and return original scaling factor inverses."""
-    if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase):
-        return None
-
-    original_scale_inverses = (
-        A._rowwise_scale_inv,
-        A._columnwise_scale_inv,
-        B._rowwise_scale_inv,
-        B._columnwise_scale_inv,
-    )
-
-    if layout[0] == "T":
-        A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv)
-    else:
-        A._columnwise_scale_inv = tex.columnwise_swizzle(A._columnwise_data, A._columnwise_scale_inv)
-
-    if layout[1] == "N":
-        B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv)
-    else:
-        B._columnwise_scale_inv = tex.columnwise_swizzle(B._columnwise_data, B._columnwise_scale_inv)
-
-    return original_scale_inverses
-
-
-def reset_swizzled_inputs(A, B, scale_inverses):
-    """Reset the swizzled scale inverses after GEMM."""
-    if scale_inverses is not None:
-        (
-            A._rowwise_scale_inv,
-            A._columnwise_scale_inv,
-            B._rowwise_scale_inv,
-            B._columnwise_scale_inv,
-        ) = scale_inverses
-

 def general_gemm(
     A: torch.Tensor,
     B: torch.Tensor,

@@ -110,9 +70,20 @@ def general_gemm(
     if not out.is_contiguous():
         raise ValueError("Output tensor is not contiguous.")

+    debug_quantizer = None
+    if isinstance(quantization_params, DebugQuantizer):
+        debug_quantizer = quantization_params
+        quantization_params = quantization_params.parent_quantizer
+        A = A.get_tensor(not transa)
+        B = B.get_tensor(transb)
+
     # Use bfloat16 as default bias_dtype
     bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]

+    if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
+        # There is not use_split_accumulator == False
+        # implementation for Float8BlockwiseQTensorBase GEMM
+        use_split_accumulator = True
+
     args = (
         A,
         transa,  # transa

@@ -138,9 +109,10 @@ def general_gemm(
         "bulk_overlap": bulk_overlap,
     }

-    original_scale_inverses = swizzle_inputs(A, B, layout)
     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
-    reset_swizzled_inputs(A, B, original_scale_inverses)
+
+    if debug_quantizer is not None:
+        out = debug_quantizer.process_gemm_output(out)

     return out, bias_grad, gelu_input, extra_output

@@ -170,14 +142,6 @@ def general_grouped_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"

-    # assert [a.is_contiguous() for a in A]
-    # assert [b.is_contiguous() for b in B]
-
-    if isinstance(A[0], Float8TensorBase):
-        for a, b in zip(A, B):
-            assert_dim_for_fp8_exec(a._data)
-            assert_dim_for_fp8_exec(b._data)
-
     empty_tensor = _empty_tensor()
     empty_tensors = [empty_tensor] * num_gemms
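The DebugQuantizer handling in general_gemm follows a simple unwrap / run / re-process pattern. A minimal standalone sketch of the same idea (DummyDebugQuantizer and run_gemm are hypothetical stand-ins, not TE APIs):

# Minimal sketch of the unwrap/re-process pattern used in general_gemm above.
class DummyDebugQuantizer:
    """Hypothetical stand-in for DebugQuantizer."""
    def __init__(self, parent):
        self.parent_quantizer = parent
    def process_gemm_output(self, out):
        # e.g. record statistics, then pass the tensor through unchanged
        return out

def gemm_with_debug(a, b, quantization_params, run_gemm):
    debug_quantizer = None
    if isinstance(quantization_params, DummyDebugQuantizer):
        # Unwrap: run the GEMM with the real quantizer underneath.
        debug_quantizer = quantization_params
        quantization_params = quantization_params.parent_quantizer
    out = run_gemm(a, b, quantization_params)
    if debug_quantizer is not None:
        # Re-process: give the debug layer a chance to inspect the output.
        out = debug_quantizer.process_gemm_output(out)
    return out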
transformer_engine/pytorch/cpu_offload.py

@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"]
 CPUOffloadEnabled = False


-def set_offloading_param(tensor, param_name, value):
+def mark_activation_offload(*tensors):
     """Set the type of the offloading needed for a tensor."""
-    assert param_name in ["weight_offloading", "activation_offloading"]
-    if tensor is None:
-        return
-    if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
-        setattr(tensor, param_name, value)
-    else:
-        data_tensors = tensor.get_data_tensors()
-        for tensor in data_tensors:
-            if tensor is not None:
-                setattr(tensor, param_name, value)
+    for tensor in tensors:
+        if tensor is None:
+            continue
+        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
+            tensor.activation_offloading = True
+        else:
+            data_tensors = tensor.get_data_tensors()
+            for tensor in data_tensors:
+                if tensor is not None:
+                    tensor.activation_offloading = True
+                    # This is a hack to force clear the tensor after it is offloaded.
+                    # It is needed, because .*TensorBase classes are saved in the ctx,
+                    # and they contain the reference to their data tensors.
+                    tensor.needs_force_clear = True


 def is_cpu_offload_enabled() -> bool:

@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             torch.cuda.current_stream().wait_stream(self.d2h_stream)

         # Time to free the activation memory after usage
-        for tensor_tag, _ in self.tensor_tag_to_buf.items():
+        for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
             if tensor_tag[0] == self.offloaded_group_count:
+                if hasattr(tensor_buf, "needs_force_clear"):
+                    # Need to clear activation tensor - sometimes references persist in the code.
+                    # This is the case for example with the Float8TensorBase class,
+                    # which is saved directly inside the ctx while its internal tensors are
+                    # saved inside save_for_backward.
+                    tensor_buf.data = torch.Tensor()
+                # Release the pointer to the tensor
                 self.tensor_tag_to_buf[tensor_tag] = None

         # Time to offload the next group

@@ -538,7 +549,7 @@ def get_cpu_offload_context(
     num_layers: int = 1,
     model_layers: int = 1,
     offload_activations: bool = True,
-    offload_weights: bool = True,
+    offload_weights: bool = False,
 ):
     """
     This function returns the CPU Offload context and the synchronizer function that needs to be

@@ -570,28 +581,30 @@ def get_cpu_offload_context(
     """

-    def tensor_need_offloading_checker_activations(tensor):
-        return hasattr(tensor, "activation_offloading")
-
-    # This includes the Gradient Accumulation Buffer
-    def tensor_need_offloading_checker_weights(tensor):
-        return hasattr(tensor, "weight_offloading")
-
-    def tensor_need_offloading_checker_all(tensor):
-        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
-
-    if offload_activations and offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_all
-    elif offload_activations:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_activations
-    elif offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_weights
-    else:
+    if not offload_weights and not offload_activations:
         raise ValueError(
             "CPU Offloading is enabled while it is not "
             "mentioned what to offload (weights/activations)"
         )

+    if offload_weights:
+        import warnings
+
+        warnings.warn(
+            "Offloading weights is deprecated. Using offload_weights=True does not have any"
+            " effect.",
+            DeprecationWarning,
+        )
+
+        # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
+        if not offload_activations:
+            return nullcontext(), lambda x: x
+
+    def tensor_need_offloading_checker_activations(tensor):
+        return hasattr(tensor, "activation_offloading")
+
+    tensor_need_offloading_checker = tensor_need_offloading_checker_activations
+
     cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
         num_offload_group=num_layers,
         num_model_group=model_layers,
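A hedged usage sketch of the renamed helper and the updated context factory (only keyword arguments visible in this diff are used; values are illustrative):

# Sketch: activations-only CPU offloading after this change.
import torch
from transformer_engine.pytorch.cpu_offload import (
    get_cpu_offload_context,
    mark_activation_offload,
)

offload_ctx, sync_fn = get_cpu_offload_context(
    num_layers=2,
    model_layers=4,
    offload_activations=True,
    offload_weights=False,  # True now has no effect and emits a DeprecationWarning
)

# mark_activation_offload takes varargs and silently skips None entries:
t = torch.randn(8, 8, device="cuda")
mark_activation_offload(t, None)  # sets t.activation_offloading = True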
transformer_engine/pytorch/csrc/common.h

@@ -167,6 +167,38 @@ class Float8CurrentScalingQuantizer : public Quantizer {
       std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
 };

+class Float8BlockQuantizer : public Quantizer {
+ public:
+  // Which float8 type is used for q data.
+  DType dtype;
+
+  // Options about how to quantize the tensor
+  // Quantization scales are rounded down to powers of 2.
+  bool force_pow_2_scales = false;
+  // Amax within quantization tile has a floor of epsilon.
+  float amax_epsilon = 0.0;
+
+ private:
+  int block_scaling_dim = 2;
+
+ public:
+  // Initializes from a python handle to a Float8BlockQuantizer
+  explicit Float8BlockQuantizer(const py::handle& quantizer);
+
+  NVTEScalingMode get_scaling_mode() const override {
+    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
+  }
+
+  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
+  void set_quantization_params(TensorWrapper* tensor) const override;
+
+  // Create a python Float8BlockQuantized tensor and C++ wrapper
+  // for the tensor. Should set quantized data, scales for rowwise
+  // and optionally columnwise usage.
+  std::pair<TensorWrapper, py::object> create_tensor(
+      const std::vector<size_t>& shape, DType dtype,
+      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+};
+
 class MXFP8Quantizer : public Quantizer {
  public:
   DType dtype;
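To make the block_scaling_dim field concrete, a small sketch of the scale-tensor shapes it implies (the 128-wide block is an assumption for illustration, not taken from this header):

# Illustration (assumed 128-wide blocks): shape of the scale tensor for a
# (rows, cols) matrix under 1D vs 2D block scaling.
import math

def scale_shape(rows: int, cols: int, block_scaling_dim: int, block: int = 128):
    if block_scaling_dim == 2:  # one scale per (block x block) tile
        return (math.ceil(rows / block), math.ceil(cols / block))
    return (rows, math.ceil(cols / block))  # one scale per 1D block within a row

print(scale_shape(512, 1024, 2))  # (4, 8)
print(scale_shape(512, 1024, 1))  # (512, 8)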
transformer_engine/pytorch/csrc/extensions.h

@@ -50,11 +50,11 @@ std::vector<py::object> fused_attn_fwd(
     NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
     const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
     const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
-    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
-    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded,
+    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
+    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
+    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,

@@ -63,8 +63,8 @@ std::vector<py::object> fused_attn_bwd(
     const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
     const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
     const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer);

 at::Tensor fa_prepare_fwd(at::Tensor qkvi);

@@ -121,18 +121,22 @@ std::vector<at::Tensor> te_batchgemm_ts(
     int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
 #endif

+namespace transformer_engine::pytorch {
+
 /***************************************************************************************************
  * Transpose
  **************************************************************************************************/

-std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
-                                             std::optional<std::vector<py::handle>> output_list,
+std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
+                                             std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
                                              transformer_engine::DType otype);

 at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                          std::optional<at::Tensor> output = std::nullopt);

+}  // namespace transformer_engine::pytorch
+
 namespace transformer_engine::pytorch {

 /***************************************************************************************************

@@ -285,16 +289,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
 **************************************************************************************************/

 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory);
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank);

 at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory);
-
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
-
-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank);

 /***************************************************************************************************
  * Miscellaneous

@@ -394,10 +396,25 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> padded_input_row_list);

 /***************************************************************************************************
- * swizzle
+ * NVSHMEM APIs
  **************************************************************************************************/

-void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);
+namespace nvshmem_api {
+void init_nvshmem_backend(c10d::ProcessGroup *process_group);
+
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+
+void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
+                                    torch::Tensor signal);
+
+void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
+
+void nvshmem_finalize();
+}  // namespace nvshmem_api
+
+/***************************************************************************************************
+ * swizzle
+ **************************************************************************************************/
+
 at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
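A hedged sketch of driving the new NVSHMEM declarations from Python, assuming they are exported on transformer_engine_torch with the signatures above (the dtype choices and the wait_kind string are assumptions):

# Hypothetical usage of the NVSHMEM bindings declared above.
import torch
import torch.distributed as dist
import transformer_engine_torch as tex

dist.init_process_group(backend="nccl")
tex.init_nvshmem_backend(dist.group.WORLD)

dst = tex.create_nvshmem_tensor([1024], torch.uint8)  # symmetric buffer
signal = tex.create_nvshmem_tensor([1], torch.int64)  # signal word (dtype assumed)

src = torch.zeros(1024, dtype=torch.uint8, device="cuda")
peer = (dist.get_rank() + 1) % dist.get_world_size()
tex.nvshmem_send_on_current_stream(src, dst, peer, signal)
tex.nvshmem_wait_on_current_stream(signal, "kernel")  # wait_kind value assumed
tex.nvshmem_finalize()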
transformer_engine/pytorch/csrc/extensions/activation.cpp

@@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
     nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
-    nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+    nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    // sanity check, since activation fusion is not supported for blockwise quantization yet
+    // need to raise an error here instead of silently going into act_func with wrong numerics
+    NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
   } else {
     act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
   }
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp

@@ -7,138 +7,38 @@
 #include "extensions.h"

 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory) {
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank) {
   using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(input.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than input");
   TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
               "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(input.size(3) >= freqs.size(3),
-              "expected the last dim of the input tensor equals or is "
-              "greater than the freqs tensor");
   TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
               "Dtype of the freqs tensor must be float");

-  // input sizes: (s, b, h, d)
-  // s: sequence length
-  // b: batch size
-  // h: head num
-  // d: dim of each head
-  const int s = input.size(0);
-  const int b = input.size(1);
-  const int h = input.size(2);
-  const int d = input.size(3);
-  // input strides
-  const int stride_s = input.stride(0);
-  const int stride_b = input.stride(1);
-  const int stride_h = input.stride(2);
-  const int stride_d = input.stride(3);
-  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
-  // different memory formats
-  const int d2 = freqs.size(3);
   // output
-  auto act_options = input.options().requires_grad(false);
-  at::Tensor output;
-  if (transpose_output_memory) {
-    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    output = torch::empty({s, b, h, d}, act_options);
-  }
-  // output strides
-  const int o_stride_s = output.stride(0);
-  const int o_stride_b = output.stride(1);
-  const int o_stride_h = output.stride(2);
-  const int o_stride_d = output.stride(3);
+  auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
+  auto output = at::empty(input.sizes(), act_options);

   auto input_cu = makeTransformerEngineTensor(input);
   auto freqs_cu = makeTransformerEngineTensor(freqs);
   auto output_cu = makeTransformerEngineTensor(output);

-  nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
-                          stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                          o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
-  return output;
-}
-
-at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory) {
-  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than output_grads");
-  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
-              "expected the last dim of the output_grads tensor equals or is "
-              "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
-
-  // output_grads sizes: (s, b, h, d)
-  // s: sequence length
-  // b: batch size
-  // h: head num
-  // d: dim of each head
-  const int s = output_grads.size(0);
-  const int b = output_grads.size(1);
-  const int h = output_grads.size(2);
-  const int d = output_grads.size(3);
-  // output_grads strides
-  const int stride_s = output_grads.stride(0);
-  const int stride_b = output_grads.stride(1);
-  const int stride_h = output_grads.stride(2);
-  const int stride_d = output_grads.stride(3);
-  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
-  // different memory formats
-  const int d2 = freqs.size(3);
-
-  auto act_options = output_grads.options().requires_grad(false);
-  at::Tensor input_grads;
-  if (transpose_output_memory) {
-    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    input_grads = torch::empty({s, b, h, d}, act_options);
-  }
-  const int o_stride_s = input_grads.stride(0);
-  const int o_stride_b = input_grads.stride(1);
-  const int o_stride_h = input_grads.stride(2);
-  const int o_stride_d = input_grads.stride(3);
-
-  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
-
-  nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b,
-                           h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
-                           o_stride_b, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
-  return input_grads;
-}
-
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank) {
-  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
-  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(input.size(2) >= freqs.size(3),
-              "expected the last dim of the input tensor equals or is "
-              "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
-
-  // input sizes: (t, h, d)
-  // t: cumulative sum of sequence lengths
-  // h: head num
-  // d: dim of each head
-  const int t = input.size(0);
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
+    TORCH_CHECK(input.size(2) >= freqs.size(3),
+                "expected the last dim of the input tensor equals or is "
+                "greater than the freqs tensor");
+    // input sizes: (t, h, d)
+    // t: cumulative sum of sequence lengths
+    // h: head num
+    // d: dim of each head
+    // const int t = input.size(0);
     const int h = input.size(1);
     const int d = input.size(2);
     // input strides
     const int stride_t = input.stride(0);

@@ -146,51 +46,86 @@
     const int stride_h = input.stride(1);
     const int stride_d = input.stride(2);
     // batch size
-    const int b = cu_seqlens.size(0) - 1;
+    const int b = cu_seqlens.value().size(0) - 1;
     // freqs' shape is (max_s, 1, 1, d2)
     const int max_s = freqs.size(0);
     const int d2 = freqs.size(3);

-  // output
-  auto act_options = input.options().requires_grad(false);
-  auto output = torch::empty({t, h, d}, act_options);
-  // output strides
-  const int o_stride_t = output.stride(0);
-  const int o_stride_h = output.stride(1);
-  const int o_stride_d = output.stride(2);
-
-  auto input_cu = makeTransformerEngineTensor(input);
-  auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto output_cu = makeTransformerEngineTensor(output);
-
-  nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                              output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
-                              stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                              at::cuda::getCurrentCUDAStream());
-  return output;
-}
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+    nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                            output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
+                            h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                            at::cuda::getCurrentCUDAStream());
+  } else {
+    TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
+    // input sizes: (s, b, h, d) or (b, s, h, d)
+    // s: sequence length
+    // b: batch size
+    // h: head num
+    // d: dim of each head
+    const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
+    const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
+    const int h = input.size(2);
+    const int d = input.size(3);
+    // input strides
+    const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
+    const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
+    const int stride_h = input.stride(2);
+    const int stride_d = input.stride(3);
+    // freqs' shape is always (s, 1, 1, d2), so the strides are same under
+    // different memory formats
+    const int d2 = freqs.size(3);
+    TORCH_CHECK(s * cp_size <= freqs.size(0),
+                "expected freqs tensor has a longer sequence length than input");
+    TORCH_CHECK(d >= d2,
+                "expected the last dim of the input tensor equals or is "
+                "greater than the freqs tensor");
+    auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+    nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                            output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h,
+                            d, d2, stride_s, stride_b, stride_h, stride_d,
+                            at::cuda::getCurrentCUDAStream());
+  }
   return output;
 }

-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank) {
+at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank) {
   using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
   TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
               "expected the second and third dims of the freqs tensor equal 1");
+  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
+              "Dtype of the freqs tensor must be float");
+
+  auto act_options =
+      at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
+  auto input_grads = at::empty(output_grads.sizes(), act_options);
+
+  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
+  auto freqs_cu = makeTransformerEngineTensor(freqs);
+  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
+
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
+    TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
+                "expected the last dim of the output_grads tensor equals or is "
+                "greater than the freqs tensor");
     // output_grads sizes: (t, h, d)
     // t: cumulative sum of sequence lengths
     // h: head num
     // d: dim of each head
-  const int t = output_grads.size(0);
+    // const int t = output_grads.size(0);
     const int h = output_grads.size(1);
     const int d = output_grads.size(2);
     // output_grads strides
     const int stride_t = output_grads.stride(0);

@@ -198,25 +133,54 @@
     const int stride_h = output_grads.stride(1);
     const int stride_d = output_grads.stride(2);
     // batch size
-    const int b = cu_seqlens.size(0) - 1;
+    const int b = cu_seqlens.value().size(0) - 1;
     // freqs' shape is (max_s, 1, 1, d2)
     const int max_s = freqs.size(0);
     const int d2 = freqs.size(3);

-  auto act_options = output_grads.options().requires_grad(false);
-  auto input_grads = torch::empty({t, h, d}, act_options);
-  const int o_stride_t = input_grads.stride(0);
-  const int o_stride_h = input_grads.stride(1);
-  const int o_stride_d = input_grads.stride(2);
-
-  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-  auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
-
-  nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                               input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
-                               stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                               at::cuda::getCurrentCUDAStream());
-  return input_grads;
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                             input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
+                             max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                             at::cuda::getCurrentCUDAStream());
+  } else {
+    TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
+    // output_grads sizes: (s, b, h, d) or (b, s, h, d)
+    // s: sequence length
+    // b: batch size
+    // h: head num
+    // d: dim of each head
+    const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0)
+                                                           : output_grads.size(1);
+    const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1)
+                                                           : output_grads.size(0);
+    const int h = output_grads.size(2);
+    const int d = output_grads.size(3);
+    // output_grads strides
+    const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0)
+                                                                  : output_grads.stride(1);
+    const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1)
+                                                                  : output_grads.stride(0);
+    const int stride_h = output_grads.stride(2);
+    const int stride_d = output_grads.stride(3);
+    // freqs' shape is always (s, 1, 1, d2), so the strides are same under
+    // different memory formats
+    const int d2 = freqs.size(3);
+    TORCH_CHECK(s * cp_size <= freqs.size(0),
+                "expected freqs tensor has a longer sequence length than output_grads");
+    TORCH_CHECK(d >= d2,
+                "expected the last dim of the output_grads tensor equals or is "
+                "greater than the freqs tensor");
+    auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                             input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s,
+                             b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
+                             at::cuda::getCurrentCUDAStream());
+  }
   return input_grads;
 }
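For orientation, a hedged sketch of the unified binding from the Python side; the argument order follows the header declaration above and the enum exposure on transformer_engine_torch is an assumption.

# Sketch of calling the unified RoPE binding (treat as an approximation, not
# the exact binding).
import torch
import transformer_engine_torch as tex

s, b, h, d = 128, 2, 16, 64
inp = torch.randn(s, b, h, d, device="cuda")
freqs = torch.randn(s, 1, 1, d, device="cuda", dtype=torch.float32)

out = tex.fused_rope_forward(
    inp, freqs,
    tex.NVTE_QKV_Format.NVTE_SBHD,  # qkv_format replaces transpose_output_memory
    False,                          # interleaved
    None,                           # cu_seqlens (only required for THD)
    1, 0,                           # cp_size, cp_rank (context parallelism)
)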
transformer_engine/pytorch/csrc/extensions/attention.cu  View file @ ab3e5a92
@@ -3,9 +3,11 @@
  *
  * See LICENSE for license information.
  ************************************************************************/
 #include "extensions.h"
 #include "kv_cache.cuh"
 #include "thd_utils.cuh"
+#include "transformer_engine/transformer_engine.h"

 constexpr int block_size = 512;
 constexpr int ctas_per_sm = 4;
@@ -95,11 +97,11 @@ std::vector<py::object> fused_attn_fwd(
     NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
     const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
     const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
-    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
-    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded,
+    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
+    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
+    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);
 #else
@@ -289,8 +291,8 @@ std::vector<py::object> fused_attn_bwd(
     const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
     const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
     const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);
@@ -461,13 +463,13 @@ std::vector<py::object> fused_attn_bwd(
   nvte_tensor_pack_create(&nvte_aux_tensor_pack);
   nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
   for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
-    std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
-    auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
-    const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
-    NVTEBasicTensor temp_data = {Aux_CTX_Tensors[i].data_ptr(),
-                                 static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
-                                 temp_shape};
+    const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec();
+    const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end());
+    NVTEBasicTensor temp_data = {Aux_CTX_Tensors[i].data_ptr(),
+                                 static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
+                                 nvte_make_shape(tmp.data(), tmp.size())};
     nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
   }
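Two mechanical changes run through this file: c10::optional becomes std::optional (the former is an alias of the latter in recent PyTorch, so call sites should be unaffected), and the aux-tensor shape is now produced by nvte_make_shape instead of brace-initializing an NVTEShape, leaving the struct layout to the C API. A hedged sketch of the signed-to-unsigned shape hop with stand-in types (FakeShape and make_shape are illustrative, not TE's headers):

#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for NVTEShape; the real layout belongs to the C API, which is why
// construction goes through a factory function in the hunk above.
struct FakeShape {
  const size_t *data;
  size_t ndim;
};

FakeShape make_shape(const size_t *data, size_t ndim) { return {data, ndim}; }

int main() {
  const std::vector<int64_t> signed_shape = {2, 3, 4};  // torch-style sizes
  const std::vector<size_t> dims(signed_shape.begin(), signed_shape.end());
  const FakeShape s = make_shape(dims.data(), dims.size());  // dims must outlive s
  return static_cast<int>(s.ndim) - 3;  // exits 0 on success
}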
transformer_engine/pytorch/csrc/extensions/cast.cpp  View file @ ab3e5a92
@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   if (te_output.numel() == 0) return out;

+  QuantizationConfigWrapper quant_config;
+  quant_config.set_noop_tensor(te_noop.data());
+
   if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
@@ -61,14 +64,20 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
       allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
       process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
     }
-    QuantizationConfigWrapper quant_config;
+    // this config is used for the current-scaling scale factor computation;
+    // the scale computation cannot be fused with the quantize kernel, so with
+    // current scaling the quant config is not used again inside nvte_quantize_v2
     quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
     nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
+    quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
+    quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
   }
-  nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
-                     at::cuda::getCurrentCUDAStream());
+  nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
+                   at::cuda::getCurrentCUDAStream());
   return out;
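The net effect of this hunk: one QuantizationConfigWrapper is created at function scope, the no-op tensor is attached up front, per-quantizer knobs (current-scaling or blockwise) are filled in, and a single nvte_quantize_v2 call replaces nvte_quantize_noop. A reduced model of that control flow (QuantConfig and Kind are stand-ins, not TE types):

enum class Kind { DelayedScaling, CurrentScaling, Blockwise };

struct QuantConfig {
  const void *noop = nullptr;
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0f;
};

QuantConfig build_config(Kind kind, const void *noop_tensor, bool pow2, float eps) {
  QuantConfig cfg;
  cfg.noop = noop_tensor;  // attached unconditionally, as in the hunk above
  if (kind == Kind::CurrentScaling || kind == Kind::Blockwise) {
    cfg.force_pow_2_scales = pow2;  // copied from the quantizer object in TE
    cfg.amax_epsilon = eps;
  }
  return cfg;  // the one config that feeds nvte_quantize_v2
}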
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp  View file @ ab3e5a92
@@ -157,15 +157,15 @@ void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool
   char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
   if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() * _tp_size > _ubuf.numel())
       NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
     ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
   } else {
-    if (input_tensor.numel() > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() > _ubuf.numel())
       NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
   }
@@ -189,7 +189,7 @@ py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
   std::vector<int64_t> torch_shape;
   if (shape.has_value()) {
     torch_shape = shape.value();
-    auto requested = product(torch_shape);
+    size_t requested = product(torch_shape);
     auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
     NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
                ") does not match allocated buffer size (", expected, ")!");
@@ -253,18 +253,18 @@ void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bo
   at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   if (local_chunk) {
     // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() * _tp_size > _ubuf.numel())
       NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
     NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
                                     input_tensor.numel() * input_tensor.element_size(),
                                     cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
   } else {
-    if (input_tensor.numel() > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() > _ubuf.numel())
       NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
     NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
                                     input_tensor.numel() * input_tensor.element_size(),
@@ -280,7 +280,7 @@ py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
   std::vector<int64_t> torch_shape;
   if (shape.has_value()) {
     torch_shape = shape.value();
-    auto requested = product(torch_shape);
+    size_t requested = product(torch_shape);
     auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
     NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
                ") does not match allocated buffer size (", expected, ")!");
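The dropped (int64_t) casts suggest that _ubuf.numel() and _ubuf.element_size() now return a signed type matching at::Tensor's accessors; once both sides of each comparison share a type, the casts only added noise. For illustration (signatures assumed, not TE's):

#include <cstdint>

// If numel()/element_size() return int64_t on both objects, the comparison is
// already homogeneous and the old explicit casts were redundant.
bool fits_in_buffer(int64_t input_numel, int64_t tp_size, int64_t ubuf_numel) {
  return input_numel * tp_size <= ubuf_numel;
}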
transformer_engine/pytorch/csrc/extensions/gemm.cpp  View file @ ab3e5a92
@@ -21,6 +21,7 @@
 #include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
+#include "util.h"

 namespace {
@@ -179,8 +180,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   const int sm_count = transformer_engine::cuda::sm_count(device_id);
   int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

+  // Keep the swizzled scaling factor tensors alive during the GEMM.
+  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
+
   auto main_stream = at::cuda::getCurrentCUDAStream();
   if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
+    // Optionally swizzle the scaling factors
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(B_tensor, !transb)));
+
     if (comm_overlap) {
       // Prepare extra output tensor
       TensorWrapper extra_output_tensor;
@@ -317,17 +325,18 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       te_pre_gelu_out_vector, te_workspace_vector;
   std::vector<TensorWrapper> wrappers;
   std::vector<at::Tensor> D_vectors;
+  // Keep the swizzled scaling factor tensors alive during the GEMMs.
+  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

   auto none = py::none();
   std::vector<size_t> single_output_begins;
   std::vector<size_t> single_output_ends;
-  int slicing_dim;

   if (single_output && D == std::nullopt) {
     NVTE_ERROR("not implemented, D should be allocated for single output case.");
   }
-  void *output_data_ptr;
+  void *output_data_ptr = nullptr;
   if (single_output) {
     output_data_ptr = (*D)[0].data_ptr();
   }
@@ -384,6 +393,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       continue;
     }
+    // Optionally swizzle the scaling factors
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
+
     auto te_D = makeTransformerEngineTensor(out_tensor);
     auto te_bias = makeTransformerEngineTensor(bias[i]);
     auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
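Both GEMM paths now stash the optional swizzled scale-inverse tensors in a vector that lives until the function returns, so the device buffers an asynchronously launched kernel reads cannot be freed early. A hedged sketch of that lifetime idiom (Buffer and maybe_swizzle are illustrative, not TE's helpers):

#include <optional>
#include <vector>

struct Buffer { /* stands in for a swizzled scale-inverse tensor */ };

std::optional<Buffer> maybe_swizzle(bool needed) {
  return needed ? std::optional<Buffer>(Buffer{}) : std::nullopt;
}

void run_gemms(int n) {
  std::vector<std::optional<Buffer>> keep_alive;  // outlives every launch below
  for (int i = 0; i < n; ++i) {
    keep_alive.emplace_back(maybe_swizzle(i % 2 == 0));
    // ... enqueue GEMM i here; the async kernel may read keep_alive.back() ...
  }
}  // storage released only after all work was issued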
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu  View file @ ab3e5a92
@@ -12,6 +12,8 @@
 // #include <torch/all.h>
 #include <assert.h>
+#include <limits>
+
 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <sstream>
@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor {
     n -= chunk_idx * chunk_size;
     for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
-      float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8,
-                                                                    force_pow_2_scales, epsilon);
+      float scale_val = transformer_engine::compute_scale_from_amax(
+          amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits<float>::max());
       scale[i_start] = scale_val;
       transformer_engine::reciprocal(scale_inv + i_start, scale_val);
     }
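The extra argument threaded into compute_scale_from_amax is std::numeric_limits<float>::max(); by position it reads as an upper bound on the returned scale, i.e. effectively "no clamp" for this caller. A hedged scalar model of a scale-from-amax computation with such a cap (illustrative only, not TE's device function):

#include <algorithm>
#include <cmath>
#include <limits>

float scale_from_amax(float amax, float max_fp8, bool pow2, float eps,
                      float max_scale = std::numeric_limits<float>::max()) {
  float clamped_amax = std::max(amax, eps);  // guard against divide-by-zero
  float scale = max_fp8 / clamped_amax;      // map amax onto the FP8 range
  if (pow2) scale = std::exp2(std::floor(std::log2(scale)));  // round down to 2^k
  return std::min(scale, max_scale);         // the new final clamp
}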
transformer_engine/pytorch/csrc/extensions/normalization.cpp  View file @ ab3e5a92
@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    QuantizationConfigWrapper quant_config;
     if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
       // my_quantizer here has to be a Float8CurrentScalingQuantizer
       auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
@@ -166,14 +167,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
         allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
         process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
       }
-      QuantizationConfigWrapper quant_config;
       quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
       nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
       // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
       out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+      auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
+      quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
     }
-    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
-                       at::cuda::getCurrentCUDAStream());
+    nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
   }
@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    QuantizationConfigWrapper quant_config;
     if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
       // my_quantizer here has to be a Float8CurrentScalingQuantizer
       auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
@@ -309,14 +314,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
         allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
         process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
       }
-      QuantizationConfigWrapper quant_config;
       quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
       nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
       // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
       out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+      auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
+      quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
     }
-    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
-                       at::cuda::getCurrentCUDAStream());
+    nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
   }
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp  0 → 100644  View file @ ab3e5a92
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../extensions.h"

#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif

#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>

namespace nvshmem_api {

void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
  nvshmemx_init_attr_t attr = {};
  nvshmemx_uniqueid_t id = {};

  int my_rank = process_group->getRank();
  int num_ranks = process_group->getSize();
  if (my_rank == 0) {
    nvshmemx_get_uniqueid(&id);
  }

  auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
  NVTE_CHECK(backend_is_nccl, "Currently only support NCCL boostrap for NVSHMEM");

  auto datatensor =
      torch::from_blob(reinterpret_cast<void *>(&id),
                       {static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
                       at::device(torch::kCPU).dtype(torch::kUInt8));
  auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;

  c10d::BroadcastOptions bcast_opts;
  bcast_opts.rootRank = 0;
  std::vector<torch::Tensor> datachunk = {datatmp};
  auto work = process_group->broadcast(datachunk, bcast_opts);
  work->wait();

  if (backend_is_nccl) {
    datatensor.copy_(datatmp.cpu());
    datatmp = torch::Tensor();
  }

  nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

  NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
             " != nvshmem_my_pe(): ", nvshmem_my_pe());
  NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
             " != nvshmem_n_pes(): ", nvshmem_n_pes());
#else
  NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ",
             "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
  cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();

  WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;
  if (wait_kind == "kernel") {
    wait_kind_enum = WaitKind::KERNEL_WAIT;
  } else if (wait_kind == "nvshmem") {
    wait_kind_enum = WaitKind::NVSHMEM_WAIT;
  } else if (wait_kind == "stream") {
    wait_kind_enum = WaitKind::STREAM_WAIT;
  } else {
    NVTE_ERROR("Invalid wait kind: ", wait_kind);
  }
  nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);
#else
  NVTE_ERROR(
      "Internal TE error: nvshmem_wait_on_current_stream cannot be initialized with valid PyTorch ",
      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
  auto option_gpu =
      at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
  auto size = torch::elementSize(dtype) *
              std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
  return at::from_blob(
      nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
#else
  NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ",
             "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
                                    torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
  void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
  void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
  auto nelement = src.numel() * src.element_size();
  uint64_t sigval = 1;
  at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();

  nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval,
                                   NVSHMEM_SIGNAL_SET, peer, (cudaStream_t)cur_stream);
#else
  NVTE_ERROR(
      "Internal TE error: nvshmem_send_on_current_stream cannot be initialized with valid PyTorch ",
      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
  nvshmem_finalize();
#else
  NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ",
             "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
}  // namespace nvshmem_api
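The new bindings implement a one-sided handshake: the sender writes the payload and then sets a signal word (nvshmemx_putmem_signal_on_stream with NVSHMEM_SIGNAL_SET and sigval 1), while the receiver blocks until the signal changes. A host-only model of that protocol, using plain threads in place of two PEs so it runs without NVSHMEM (an illustration of the semantics, not the library API):

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>

std::atomic<uint64_t> signal_word{0};
int mailbox = 0;

int main() {
  std::thread sender([] {
    mailbox = 42;                                     // putmem: payload first
    signal_word.store(1, std::memory_order_release);  // then SIGNAL_SET, sigval = 1
  });
  // analogue of nvshmem_wait_on_current_stream(signal, "stream"):
  while (signal_word.load(std::memory_order_acquire) == 0) {
  }
  std::printf("received %d\n", mailbox);
  sender.join();
  return 0;
}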
transformer_engine/pytorch/csrc/extensions/padding.cpp  View file @ ab3e5a92
@@ -17,7 +17,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
   NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
-  const int num_tensors = input_row_list.size();
+  const auto num_tensors = input_row_list.size();

   // Extract properties from PyTorch tensors
   std::vector<void *> input_dptr_list, output_dptr_list;
   std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
transformer_engine/pytorch/csrc/extensions/permutation.cu  View file @ ab3e5a92
@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
       sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
       num_tokens * topK);

-  // Activations type
-  at::ScalarType _st;
-  if (dtype == transformer_engine::DType::kFloat8E4M3 ||
-      dtype == transformer_engine::DType::kFloat8E5M2)
-    _st = at::ScalarType::Byte;
-  else
-    _st = input.scalar_type();
-
   // Output buffer alloc
   num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
-  at::Tensor permuted_output = torch::empty(
-      {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
+  at::Tensor permuted_output =
+      torch::empty({num_out_tokens, num_cols},
+                   torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
   at::Tensor row_id_map = torch::empty(
       {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
   using namespace transformer_engine::pytorch;
   int num_cols = input.size(1);

-  // Activations type
-  at::ScalarType _st;
-  if (dtype == transformer_engine::DType::kFloat8E4M3 ||
-      dtype == transformer_engine::DType::kFloat8E5M2)
-    _st = at::ScalarType::Byte;
-  else
-    _st = input.scalar_type();
-
   // Output buffer alloc
-  at::Tensor unpermuted_output = torch::empty(
-      {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
+  at::Tensor unpermuted_output =
+      torch::empty({num_tokens, num_cols},
+                   torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));

   auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
   const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
   int num_cols = input_bwd.size(1);

-  // Activations type
-  at::ScalarType _st;
-  if (dtype == transformer_engine::DType::kFloat8E4M3 ||
-      dtype == transformer_engine::DType::kFloat8E5M2)
-    _st = at::ScalarType::Byte;
-  else
-    _st = input_bwd.scalar_type();
-
   // Output buffer alloc
-  at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols},
-                                     torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
+  at::Tensor act_grad =
+      torch::empty({input_fwd.size(0), num_cols},
+                   torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
   at::Tensor prob_grad = torch::empty(
       {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
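The deleted _st logic viewed FP8 activations as at::ScalarType::Byte; the replacement trusts input.scalar_type() directly, presumably because FP8 tensors now reach these kernels already carrying a native torch dtype rather than a raw byte view. A reduced model of the removed branch, with simplified stand-in enums:

enum class DType { kFloat8E4M3, kFloat8E5M2, kBFloat16 };
enum class ScalarType { Byte, BFloat16 };

ScalarType old_style(DType d, ScalarType input_st) {
  if (d == DType::kFloat8E4M3 || d == DType::kFloat8E5M2)
    return ScalarType::Byte;  // byte view, which discarded the FP8 dtype
  return input_st;            // the new code always takes this path via scalar_type()
}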
transformer_engine/pytorch/csrc/extensions/pybind.cpp  View file @ ab3e5a92
@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
 PyTypeObject *MXFP8TensorPythonClass = nullptr;  /// TODO Remove
 PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
 PyTypeObject *MXFP8QuantizerClass = nullptr;
+PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
+PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
+PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;

 void init_float8_extension() {
   if (Float8TensorPythonClass) return;
@@ -61,9 +64,31 @@ void init_mxfp8_extension() {
              "Internal error: could not initialize pyTorch MXFP8 extension.");
 }

+void init_float8blockwise_extension() {
+  if (Float8BlockwiseQTensorBasePythonClass) return;
+  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
+  auto fp8_base_module =
+      py::module_::import("transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
+  Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
+  Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
+  Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
+  NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
+             "Internal error: could not initialize pyTorch float8blockwise extension.");
+  NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
+             "Internal error: could not initialize pyTorch float8blockwise extension.");
+  NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
+             "Internal error: could not initialize pyTorch float8blockwise extension.");
+}
+
 void init_extension() {
   init_float8_extension();
   init_mxfp8_extension();
+  init_float8blockwise_extension();
 }

 }  // namespace transformer_engine::pytorch
@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("output") = py::none(), py::arg("noop") = py::none());
   m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
         py::arg("otype"));
   m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
         "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
   m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
@@ -170,15 +196,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
         py::arg("zero_centered_gamma"));
   m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
-  m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose",
-        py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
+  m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
+        "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
+        py::arg("quantizer_list"), py::arg("otype"));
   m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
 #ifdef USE_ROCM
   m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM");  /// rocblas
 #endif
-  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
-        py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
+  m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
+        py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
+        py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
   m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
@@ -206,10 +234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
         py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
-        py::call_guard<py::gil_scoped_release>());

   // Misc
   m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
@@ -240,6 +264,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());

+  // nvshmem functions
+  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
+        "Initialize nvshmem backend with Pytorch distributed process groups",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
+        "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
+        "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
+        "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
+        "stream",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
+        "Clean up and finalize the NVSHMEM communication backend and free associated resources",
+        py::call_guard<py::gil_scoped_release>());
+
   // multi-tensor functions
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",