OpenDAS / TransformerEngine / Commits / ab3e5a92

Commit ab3e5a92, authored May 09, 2025 by yuguo

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine

Parents: a8d19fd9, 04c730c0

Changes: 174 files in the full commit; the first 20 changed files are shown below, with 573 additions and 351 deletions (+573 / -351).
transformer_engine/jax/setup.py                                           +1    -1
transformer_engine/jax/sharding.py                                        +28   -2
transformer_engine/pytorch/attention.py                                   +23   -11
transformer_engine/pytorch/constants.py                                   +6    -0
transformer_engine/pytorch/cpp_extensions/gemm.py                         +17   -53
transformer_engine/pytorch/cpu_offload.py                                 +43   -30
transformer_engine/pytorch/csrc/common.h                                  +32   -0
transformer_engine/pytorch/csrc/extensions.h                              +36   -19
transformer_engine/pytorch/csrc/extensions/activation.cpp                 +6    -1
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp                 +125  -161
transformer_engine/pytorch/csrc/extensions/attention.cu                   +13   -11
transformer_engine/pytorch/csrc/extensions/cast.cpp                       +12   -3
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp          +10   -10
transformer_engine/pytorch/csrc/extensions/gemm.cpp                       +15   -2
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu  +4  -2
transformer_engine/pytorch/csrc/extensions/normalization.cpp              +14   -6
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp (new file)    +129  -0
transformer_engine/pytorch/csrc/extensions/padding.cpp                    +1    -1
transformer_engine/pytorch/csrc/extensions/permutation.cu                 +9    -30
transformer_engine/pytorch/csrc/extensions/pybind.cpp                     +49   -8
transformer_engine/jax/setup.py

@@ -101,7 +101,7 @@ if __name__ == "__main__":
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
         install_requires=["jax", "flax>=0.7.1"],
-        tests_require=["numpy", "praxis"],
+        tests_require=["numpy"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
         shutil.rmtree(common_headers_dir)
transformer_engine/jax/sharding.py

@@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names):
     Convert logical axes to PartitionSpec
     """
     rules = get_sharding_map_logic_axis_to_mesh_axis()
-    mesh_axis_names = [rules[name] for name in logical_axis_names]
+    # mesh_axis_names = [rules[name] for name in logical_axis_names]
+    mesh_axis_names = []
+    for name in logical_axis_names:
+        axis_name = rules[name] if name in rules else None
+        mesh_axis_names.append(axis_name)
     pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
     return pspec

@@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t
     """
    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
     """
-    if logical_axis_names is None:
+    if not logical_axis_names:
         return x

     assert len(x.shape) == len(logical_axis_names)

@@ -315,3 +319,25 @@ class ShardingType(Enum):
     TP_ROW = (MajorShardingType.TP, "tp_row")
     DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
+
+
+def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+    """Get logical axes for non-contracting dimensions.
+
+    Args:
+        ndim: Number of dimensions in the tensor.
+        logical_axes: Tuple of logical axes for each dimension.
+        contracting_dims: Set of dimensions that are being contracted.
+
+    Returns:
+        Tuple of logical axes for non-contracting dimensions.
+    """
+    if not logical_axes:
+        logical_axes = (None,) * ndim
+    elif len(logical_axes) < ndim:
+        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
+    assert len(logical_axes) == ndim
+
+    non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
+    non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
+    return non_contracting_logical_axes
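
The new helper is a pure function over tuples, so its padding and filtering behavior is easy to check in isolation. A minimal sketch, reimplementing the helper exactly as in the hunk above (the axis names are made up for illustration):

def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
    # Pad missing trailing axes with None, as in the diff above.
    if not logical_axes:
        logical_axes = (None,) * ndim
    elif len(logical_axes) < ndim:
        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
    assert len(logical_axes) == ndim
    non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
    return tuple(logical_axes[i] for i in non_contracting_dims)

# For a 3D tensor contracted over its last dim, only the leading axes survive:
assert get_non_contracting_logical_axes(3, ("batch", "seq", "hidden"), {2}) == ("batch", "seq")
# A short tuple is right-padded with None before filtering:
assert get_non_contracting_logical_axes(3, ("batch",), {2}) == ("batch", None)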
transformer_engine/pytorch/attention.py

@@ -20,6 +20,7 @@ import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
+from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.utils import (
     get_cudnn_version,
     nvtx_range_pop,

@@ -81,6 +82,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
+from .cpu_offload import mark_activation_offload

 # Setup Attention Logging

@@ -618,7 +620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(cp_group)
         send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type

@@ -1566,7 +1568,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(ctx.cp_group)
         send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
         q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
             restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

@@ -4323,10 +4325,9 @@ class FlashAttention(torch.nn.Module):
         from .cpu_offload import CPUOffloadEnabled

         if CPUOffloadEnabled:
-            tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
-            for tensor in tensor_list:
-                if tensor is not None:
-                    tensor.activation_offloading = True
+            mark_activation_offload(
+                query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
+            )

         with self.attention_dropout_ctx():
             # | API | use cases

@@ -4728,12 +4729,9 @@ class FusedAttnFunc(torch.autograd.Function):
             else:
                 tensor_list = [q, k, v, out_save]
-            tensor_list.extend(aux_ctx_tensors)

             qkv_layout = "sbhd_sbhd_sbhd"
-            for tensor in tensor_list:
-                if tensor is not None:
-                    tensor.activation_offloading = True
+            mark_activation_offload(*tensor_list)
+            mark_activation_offload(*aux_ctx_tensors)

         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8

@@ -6482,6 +6480,8 @@ class MultiheadAttention(torch.nn.Module):
                    equal length. Please note that these formats do not reflect how
                    tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
                    For that, please use `get_qkv_layout` to gain the layout information.
+    name: str, default = `None`
+        name of the module, currently used for debugging purposes.

     Parallelism parameters
     ----------------------

@@ -6560,6 +6560,7 @@ class MultiheadAttention(torch.nn.Module):
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
         qkv_format: str = "sbhd",
+        name: str = None,
     ) -> None:
         super().__init__()

@@ -6611,6 +6612,8 @@ class MultiheadAttention(torch.nn.Module):
         self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
         self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

+        self.name = name
+
         common_gemm_kwargs = {
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
             "tp_group": tp_group,

@@ -6651,6 +6654,7 @@ class MultiheadAttention(torch.nn.Module):
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
+                name=name + ".layernorm_linear_qkv" if name is not None else None,
                 **common_gemm_kwargs,
             )
         else:

@@ -6662,6 +6666,7 @@ class MultiheadAttention(torch.nn.Module):
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 parameters_split=parameters_split,
+                name=name + ".linear_qkv" if name is not None else None,
                 **common_gemm_kwargs,
             )
         elif self.attention_type == "cross":

@@ -6683,6 +6688,7 @@ class MultiheadAttention(torch.nn.Module):
                 ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
                 ub_name="qkv",
+                name=name + ".layernorm_linear_q" if name is not None else None,
                 **common_gemm_kwargs,
             )
         else:

@@ -6693,6 +6699,7 @@ class MultiheadAttention(torch.nn.Module):
                 bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
+                name=name + ".linear_q" if name is not None else None,
                 **common_gemm_kwargs,
             )
             self.key_value = Linear(

@@ -6703,6 +6710,7 @@ class MultiheadAttention(torch.nn.Module):
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 parameters_split=("key", "value") if not fuse_qkv_params else None,
+                name=name + ".linear_kv" if name is not None else None,
                 **common_gemm_kwargs,
             )

@@ -6732,6 +6740,7 @@ class MultiheadAttention(torch.nn.Module):
             ub_overlap_rs=ub_overlap_rs,
             ub_overlap_ag=ub_overlap_ag,
             ub_name="proj",
+            name=name + ".proj" if name is not None else None,
             **common_gemm_kwargs,
         )

@@ -6922,6 +6931,9 @@ class MultiheadAttention(torch.nn.Module):
             core_attention_bias_type in AttnBiasTypes
         ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

+        if TEDebugState.debug_enabled:
+            TransformerEngineBaseModule._validate_name(self)
+
         # =================================================
         # Pre-allocate memory for key-value cache for inference
         # =================================================
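
The constructor threads the new `name` argument into every internal GEMM module by suffixing the parent name, so a debug tool can address submodules such as `decoder.attn.proj`. A minimal sketch of the same propagation pattern (the `Child`/`Parent` classes are stand-ins, not the TE `Linear`/`MultiheadAttention` modules):

class Child:
    def __init__(self, name=None):
        self.name = name

class Parent:
    def __init__(self, name=None):
        self.name = name
        # Suffix the parent name only when one was given, as in the diff above.
        self.proj = Child(name=name + ".proj" if name is not None else None)

assert Parent(name="decoder.attn").proj.name == "decoder.attn.proj"
assert Parent().proj.name is None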
transformer_engine/pytorch/constants.py

@@ -24,6 +24,12 @@ TE_DType = {
     torch.bfloat16: tex.DType.kBFloat16,
 }

+"""
+This is a map: int -> torch.dtype
+Used for resolving cuda extension types to torch.
+Has one to one mapping with enum in transformer_engine.h
+"""
+TE_DType_To_Torch = {
+    tex.DType.kByte: torch.uint8,
+    tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
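
TE_DType_To_Torch is the inverse of the existing TE_DType map, letting dtype enums coming back from the C++ extension be resolved to torch dtypes. A sketch of the same round-trip idea with plain Python stand-ins (the enum values here are illustrative, not the actual tex bindings):

import torch

# Stand-ins for the tex.DType enum values (illustrative only).
K_BYTE, K_BFLOAT16 = 0, 1

TE_DTYPE = {torch.uint8: K_BYTE, torch.bfloat16: K_BFLOAT16}
# Build the reverse map mechanically so the two stay in sync.
TE_DTYPE_TO_TORCH = {v: k for k, v in TE_DTYPE.items()}

assert TE_DTYPE_TO_TORCH[TE_DTYPE[torch.bfloat16]] is torch.bfloat16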
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -9,11 +9,11 @@ import os
 import torch
 import transformer_engine_torch as tex

 from ..constants import TE_DType
-from ..utils import assert_dim_for_fp8_exec, get_sm_count
+from ..utils import get_sm_count
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from ...debug.pytorch.debug_quantization import DebugQuantizer

 __all__ = [
     "general_gemm",

@@ -28,46 +28,6 @@ def _empty_tensor() -> torch.Tensor:
     return torch.Tensor().cuda()


-def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str):
-    """Swizzle gemm inputs and return original scaling factor inverses."""
-    if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase):
-        return None
-
-    original_scale_inverses = (
-        A._rowwise_scale_inv,
-        A._columnwise_scale_inv,
-        B._rowwise_scale_inv,
-        B._columnwise_scale_inv,
-    )
-
-    if layout[0] == "T":
-        A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv)
-    else:
-        A._columnwise_scale_inv = tex.columnwise_swizzle(A._columnwise_data, A._columnwise_scale_inv)
-
-    if layout[1] == "N":
-        B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv)
-    else:
-        B._columnwise_scale_inv = tex.columnwise_swizzle(B._columnwise_data, B._columnwise_scale_inv)
-
-    return original_scale_inverses
-
-
-def reset_swizzled_inputs(A, B, scale_inverses):
-    """Reset the swizzled scale inverses after GEMM."""
-    if scale_inverses is not None:
-        (
-            A._rowwise_scale_inv,
-            A._columnwise_scale_inv,
-            B._rowwise_scale_inv,
-            B._columnwise_scale_inv,
-        ) = scale_inverses
-
-
 def general_gemm(
     A: torch.Tensor,
     B: torch.Tensor,

@@ -110,9 +70,20 @@ def general_gemm(
     if not out.is_contiguous():
         raise ValueError("Output tensor is not contiguous.")

+    debug_quantizer = None
+    if isinstance(quantization_params, DebugQuantizer):
+        debug_quantizer = quantization_params
+        quantization_params = quantization_params.parent_quantizer
+        A = A.get_tensor(not transa)
+        B = B.get_tensor(transb)
+
     # Use bfloat16 as default bias_dtype
     bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]

+    if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
+        # There is not use_split_accumulator == False
+        # implementation for Float8BlockwiseQTensorBase GEMM
+        use_split_accumulator = True
+
     args = (
         A,
         transa,  # transa

@@ -138,9 +109,10 @@ def general_gemm(
         "bulk_overlap": bulk_overlap,
     }

-    original_scale_inverses = swizzle_inputs(A, B, layout)
     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
-    reset_swizzled_inputs(A, B, original_scale_inverses)
+
+    if debug_quantizer is not None:
+        out = debug_quantizer.process_gemm_output(out)

     return out, bias_grad, gelu_input, extra_output

@@ -170,14 +142,6 @@ def general_grouped_gemm(
     transa = layout[0] == "T"
     transb = layout[1] == "T"

-    # assert [a.is_contiguous() for a in A]
-    # assert [b.is_contiguous() for b in B]
-
-    if isinstance(A[0], Float8TensorBase):
-        for a, b in zip(A, B):
-            assert_dim_for_fp8_exec(a._data)
-            assert_dim_for_fp8_exec(b._data)
-
     empty_tensor = _empty_tensor()
     empty_tensors = [empty_tensor] * num_gemms
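
general_gemm now unwraps a DebugQuantizer into its parent quantizer before calling the kernel, then post-processes the output. A minimal sketch of that wrap/unwrap flow with stand-in classes (none of these are the real TE types):

class ParentQuantizer:
    def quantized_gemm(self, a, b):
        return a * b  # stand-in for the real kernel call

class DebugQuantizer:
    """Wraps a quantizer and inspects GEMM outputs without changing them."""
    def __init__(self, parent):
        self.parent_quantizer = parent
        self.seen = []
    def process_gemm_output(self, out):
        self.seen.append(out)  # e.g. log statistics here
        return out

def general_gemm(a, b, quantization_params):
    debug_quantizer = None
    if isinstance(quantization_params, DebugQuantizer):
        debug_quantizer = quantization_params
        quantization_params = quantization_params.parent_quantizer
    out = quantization_params.quantized_gemm(a, b)
    if debug_quantizer is not None:
        out = debug_quantizer.process_gemm_output(out)
    return out

dq = DebugQuantizer(ParentQuantizer())
assert general_gemm(3, 4, dq) == 12 and dq.seen == [12]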
transformer_engine/pytorch/cpu_offload.py

@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"]
 CPUOffloadEnabled = False


-def set_offloading_param(tensor, param_name, value):
+def mark_activation_offload(*tensors):
     """Set the type of the offloading needed for a tensor."""
-    assert param_name in ["weight_offloading", "activation_offloading"]
-    if tensor is None:
-        return
-    if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
-        setattr(tensor, param_name, value)
-    else:
-        data_tensors = tensor.get_data_tensors()
-        for tensor in data_tensors:
-            if tensor is not None:
-                setattr(tensor, param_name, value)
+    for tensor in tensors:
+        if tensor is None:
+            continue
+        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
+            tensor.activation_offloading = True
+        else:
+            data_tensors = tensor.get_data_tensors()
+            for tensor in data_tensors:
+                if tensor is not None:
+                    tensor.activation_offloading = True
+                    # This is a hack to force clear the tensor after it is offloaded.
+                    # It is needed, because .*TensorBase classes are saved in the ctx,
+                    # and they contain the reference to their data tensors.
+                    tensor.needs_force_clear = True


 def is_cpu_offload_enabled() -> bool:

@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         torch.cuda.current_stream().wait_stream(self.d2h_stream)

         # Time to free the activation memory after usage
-        for tensor_tag, _ in self.tensor_tag_to_buf.items():
+        for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
             if tensor_tag[0] == self.offloaded_group_count:
+                if hasattr(tensor_buf, "needs_force_clear"):
+                    # Need to clear activation tensor - sometimes references persist in the code.
+                    # This is the case for example with the Float8TensorBase class,
+                    # which is saved directly inside the ctx while its internal tensors are
+                    # saved inside save_for_backward.
+                    tensor_buf.data = torch.Tensor()
+                # Release the pointer to the tensor
                 self.tensor_tag_to_buf[tensor_tag] = None

         # Time to offload the next group

@@ -538,7 +549,7 @@ def get_cpu_offload_context(
     num_layers: int = 1,
     model_layers: int = 1,
     offload_activations: bool = True,
-    offload_weights: bool = True,
+    offload_weights: bool = False,
 ):
     """
     This function returns the CPU Offload context and the synchronizer function that needs to be

@@ -570,28 +581,30 @@ def get_cpu_offload_context(
     """

-    def tensor_need_offloading_checker_activations(tensor):
-        return hasattr(tensor, "activation_offloading")
-
-    # This includes the Gradient Accumulation Buffer
-    def tensor_need_offloading_checker_weights(tensor):
-        return hasattr(tensor, "weight_offloading")
-
-    def tensor_need_offloading_checker_all(tensor):
-        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
-
-    if offload_activations and offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_all
-    elif offload_activations:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_activations
-    elif offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_weights
-    else:
+    if not offload_weights and not offload_activations:
         raise ValueError(
             "CPU Offloading is enabled while it is not "
             "mentioned what to offload (weights/activations)"
         )
+    if offload_weights:
+        import warnings
+
+        warnings.warn(
+            "Offloading weights is deprecated. Using offload_weights=True does not have any"
+            " effect.",
+            DeprecationWarning,
+        )
+
+        # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
+        if not offload_activations:
+            return nullcontext(), lambda x: x
+
+    def tensor_need_offloading_checker_activations(tensor):
+        return hasattr(tensor, "activation_offloading")
+
+    tensor_need_offloading_checker = tensor_need_offloading_checker_activations

     cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
         num_offload_group=num_layers,
         num_model_group=model_layers,
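
The offload machinery selects tensors by the mere presence of an attribute: mark_activation_offload stamps activation_offloading onto each tensor, and the checker later tests it with hasattr. A self-contained sketch of that handshake (plain torch tensors only; TE's quantized-tensor unwrapping is omitted):

import torch

def mark_activation_offload(*tensors):
    # Stamp an attribute; its presence (not its value) is what the checker tests.
    for tensor in tensors:
        if tensor is not None:
            tensor.activation_offloading = True

def tensor_need_offloading_checker_activations(tensor):
    return hasattr(tensor, "activation_offloading")

q, k, v = torch.zeros(2), torch.zeros(2), None
mark_activation_offload(q, k, v)  # None entries are skipped
assert tensor_need_offloading_checker_activations(q)
assert not tensor_need_offloading_checker_activations(torch.zeros(2))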
transformer_engine/pytorch/csrc/common.h

@@ -167,6 +167,38 @@ class Float8CurrentScalingQuantizer : public Quantizer {
       std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
 };

+class Float8BlockQuantizer : public Quantizer {
+ public:
+  // Which float8 type is used for q data.
+  DType dtype;
+
+  // Options about how to quantize the tensor
+  // Quantization scales are rounded down to powers of 2.
+  bool force_pow_2_scales = false;
+  // Amax within quantization tile has a floor of epsilon.
+  float amax_epsilon = 0.0;
+
+ private:
+  int block_scaling_dim = 2;
+
+ public:
+  // Initializes from a python handle to a Float8BlockQuantizer
+  explicit Float8BlockQuantizer(const py::handle& quantizer);
+
+  NVTEScalingMode get_scaling_mode() const override {
+    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
+  }
+
+  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
+  void set_quantization_params(TensorWrapper* tensor) const override;
+
+  // Create a python Float8BlockQuantized tensor and C++ wrapper
+  // for the tensor. Should set quantized data, scales for rowwise
+  // and optionally columnwise usage.
+  std::pair<TensorWrapper, py::object> create_tensor(
+      const std::vector<size_t>& shape, DType dtype,
+      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+};
+
 class MXFP8Quantizer : public Quantizer {
  public:
   DType dtype;
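
get_scaling_mode switches between NVTE_BLOCK_SCALING_2D and NVTE_BLOCK_SCALING_1D based on block_scaling_dim; the difference is whether one amax/scale is tracked per square tile or per one-dimensional strip. A toy Python sketch of the two tilings (the 4-wide and 4x4 tile sizes are illustrative only, not TE's actual block sizes):

import torch

def blockwise_amax(x, block_scaling_dim):
    # Illustrative tile sizes; real kernels use larger tiles.
    if block_scaling_dim == 1:
        tiles = x.reshape(x.shape[0], -1, 4)   # 1D: strips along each row
        return tiles.abs().amax(dim=-1)        # one amax per (row, strip)
    tiles = x.reshape(x.shape[0] // 4, 4, x.shape[1] // 4, 4)
    return tiles.abs().amax(dim=(1, 3))        # 2D: one amax per 4x4 tile

x = torch.randn(8, 8)
assert blockwise_amax(x, 1).shape == (8, 2)
assert blockwise_amax(x, 2).shape == (2, 2)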
transformer_engine/pytorch/csrc/extensions.h

@@ -50,11 +50,11 @@ std::vector<py::object> fused_attn_fwd(
     NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
     const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
     const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
-    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
-    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded,
+    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
+    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
+    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,

@@ -63,8 +63,8 @@ std::vector<py::object> fused_attn_bwd(
     const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
     const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
     const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer);

 at::Tensor fa_prepare_fwd(at::Tensor qkvi);

@@ -121,18 +121,22 @@ std::vector<at::Tensor> te_batchgemm_ts(
     int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
 #endif

+namespace transformer_engine::pytorch {
+
 /***************************************************************************************************
  * Transpose
  **************************************************************************************************/

-std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
-                                             std::optional<std::vector<py::handle>> output_list,
+std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
+                                             std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
                                              transformer_engine::DType otype);

 at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                          std::optional<at::Tensor> output = std::nullopt);
+}  // namespace transformer_engine::pytorch

 namespace transformer_engine::pytorch {

 /***************************************************************************************************

@@ -285,16 +289,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
  **************************************************************************************************/

 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory);
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank);

 at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory);
-
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
-
-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank);

 /***************************************************************************************************
  * Miscellaneous

@@ -394,10 +396,25 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> padded_input_row_list);

 /***************************************************************************************************
- * swizzle
+ * NVSHMEM APIs
  **************************************************************************************************/

+namespace nvshmem_api {
+void init_nvshmem_backend(c10d::ProcessGroup *process_group);
+
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+
+void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
+                                    torch::Tensor signal);
+
+void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
+
+void nvshmem_finalize();
+}  // namespace nvshmem_api
+
+/***************************************************************************************************
+ * swizzle
+ **************************************************************************************************/
+
+void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);
+
 at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
transformer_engine/pytorch/csrc/extensions/activation.cpp

@@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
     nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
-    nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+    nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    // sanity check, since activation fusion is not supported for blockwise quantization yet
+    // need to raise an error here instead of silently going into act_func with wrong numerics
+    NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
   } else {
     act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
   }
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp

@@ -7,217 +7,181 @@
 #include "extensions.h"

 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory) {
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank) {
   using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(input.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than input");
   TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
               "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(input.size(3) >= freqs.size(3),
-              "expected the last dim of the input tensor equals or is "
-              "greater than the freqs tensor");
   TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
               "Dtype of the freqs tensor must be float");

-  // input sizes: (s, b, h, d)
+  // output
+  auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
+  auto output = at::empty(input.sizes(), act_options);
+
+  auto input_cu = makeTransformerEngineTensor(input);
+  auto freqs_cu = makeTransformerEngineTensor(freqs);
+  auto output_cu = makeTransformerEngineTensor(output);
+
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
+    TORCH_CHECK(input.size(2) >= freqs.size(3),
+                "expected the last dim of the input tensor equals or is "
+                "greater than the freqs tensor");
+
+    // input sizes: (t, h, d)
+    // t: cumulative sum of sequence lengths
+    // h: head num
+    // d: dim of each head
+    const int h = input.size(1);
+    const int d = input.size(2);
+    // input strides
+    const int stride_t = input.stride(0);
+    const int stride_h = input.stride(1);
+    const int stride_d = input.stride(2);
+    // batch size
+    const int b = cu_seqlens.value().size(0) - 1;
+    // freqs' shape is (max_s, 1, 1, d2)
+    const int max_s = freqs.size(0);
+    const int d2 = freqs.size(3);
+
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+    nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                            output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
+                            h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                            at::cuda::getCurrentCUDAStream());
+    return output;
+  }
+
+  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
+  // input sizes: (s, b, h, d) or (b, s, h, d)
   // s: sequence length
   // b: batch size
   // h: head num
   // d: dim of each head
-  const int s = input.size(0);
-  const int b = input.size(1);
+  const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
+  const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
   const int h = input.size(2);
   const int d = input.size(3);
   // input strides
-  const int stride_s = input.stride(0);
-  const int stride_b = input.stride(1);
+  const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
+  const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
   const int stride_h = input.stride(2);
   const int stride_d = input.stride(3);
   // freqs' shape is always (s, 1, 1, d2), so the strides are same under
   // different memory formats
   const int d2 = freqs.size(3);

-  // output
-  auto act_options = input.options().requires_grad(false);
-  at::Tensor output;
-  if (transpose_output_memory) {
-    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    output = torch::empty({s, b, h, d}, act_options);
-  }
-  // output strides
-  const int o_stride_s = output.stride(0);
-  const int o_stride_b = output.stride(1);
-  const int o_stride_h = output.stride(2);
-  const int o_stride_d = output.stride(3);
-
-  auto input_cu = makeTransformerEngineTensor(input);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto output_cu = makeTransformerEngineTensor(output);
-
-  nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
-                          stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                          o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
+  TORCH_CHECK(s * cp_size <= freqs.size(0),
+              "expected freqs tensor has a longer sequence length than input");
+  TORCH_CHECK(d >= d2,
+              "expected the last dim of the input tensor equals or is "
+              "greater than the freqs tensor");
+
+  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
+                          qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
+                          stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());

   return output;
 }

 at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory) {
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank) {
   using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
   TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than output_grads");
   TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
               "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
-              "expected the last dim of the output_grads tensor equals or is "
-              "greater than the freqs tensor");
   TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
               "Dtype of the freqs tensor must be float");

+  auto act_options =
+      at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
+  auto input_grads = at::empty(output_grads.sizes(), act_options);
+
+  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
+  auto freqs_cu = makeTransformerEngineTensor(freqs);
+  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
+
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
+    TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
+                "expected the last dim of the output_grads tensor equals or is "
+                "greater than the freqs tensor");
+
+    // output_grads sizes: (t, h, d)
+    // t: cumulative sum of sequence lengths
+    // h: head num
+    // d: dim of each head
+    const int h = output_grads.size(1);
+    const int d = output_grads.size(2);
+    // output_grads strides
+    const int stride_t = output_grads.stride(0);
+    const int stride_h = output_grads.stride(1);
+    const int stride_d = output_grads.stride(2);
+    // batch size
+    const int b = cu_seqlens.value().size(0) - 1;
+    // freqs' shape is (max_s, 1, 1, d2)
+    const int max_s = freqs.size(0);
+    const int d2 = freqs.size(3);
+
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                             input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
+                             max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                             at::cuda::getCurrentCUDAStream());
+    return input_grads;
+  }
+
+  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
   // output_grads sizes: (s, b, h, d) or (b, s, h, d)
   // s: sequence length
   // b: batch size
   // h: head num
   // d: dim of each head
-  const int s = output_grads.size(0);
-  const int b = output_grads.size(1);
+  const int s =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1);
+  const int b =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0);
   const int h = output_grads.size(2);
   const int d = output_grads.size(3);
   // output_grads strides
-  const int stride_s = output_grads.stride(0);
-  const int stride_b = output_grads.stride(1);
+  const int stride_s =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1);
+  const int stride_b =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0);
   const int stride_h = output_grads.stride(2);
   const int stride_d = output_grads.stride(3);
   // freqs' shape is always (s, 1, 1, d2), so the strides are same under
   // different memory formats
   const int d2 = freqs.size(3);

-  auto act_options = output_grads.options().requires_grad(false);
-  at::Tensor input_grads;
-  if (transpose_output_memory) {
-    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    input_grads = torch::empty({s, b, h, d}, act_options);
-  }
-  const int o_stride_s = input_grads.stride(0);
-  const int o_stride_b = input_grads.stride(1);
-  const int o_stride_h = input_grads.stride(2);
-  const int o_stride_d = input_grads.stride(3);
-
-  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
-
-  nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
-                           d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                           o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
+  TORCH_CHECK(s * cp_size <= freqs.size(0),
+              "expected freqs tensor has a longer sequence length than output_grads");
+  TORCH_CHECK(d >= d2,
+              "expected the last dim of the output_grads tensor equals or is "
+              "greater than the freqs tensor");
+
+  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                           input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
+                           h, d, d2, stride_s, stride_b, stride_h, stride_d,
+                           at::cuda::getCurrentCUDAStream());

   return input_grads;
 }
-
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank) {
-  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
-  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(input.size(2) >= freqs.size(3),
-              "expected the last dim of the input tensor equals or is "
-              "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
-
-  // input sizes: (t, h, d)
-  // t: cumulative sum of sequence lengths
-  // h: head num
-  // d: dim of each head
-  const int t = input.size(0);
-  const int h = input.size(1);
-  const int d = input.size(2);
-  // input strides
-  const int stride_t = input.stride(0);
-  const int stride_h = input.stride(1);
-  const int stride_d = input.stride(2);
-  // batch size
-  const int b = cu_seqlens.size(0) - 1;
-  // freqs' shape is (max_s, 1, 1, d2)
-  const int max_s = freqs.size(0);
-  const int d2 = freqs.size(3);
-
-  // output
-  auto act_options = input.options().requires_grad(false);
-  auto output = torch::empty({t, h, d}, act_options);
-  // output strides
-  const int o_stride_t = output.stride(0);
-  const int o_stride_h = output.stride(1);
-  const int o_stride_d = output.stride(2);
-
-  auto input_cu = makeTransformerEngineTensor(input);
-  auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto output_cu = makeTransformerEngineTensor(output);
-
-  nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                              output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
-                              stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                              at::cuda::getCurrentCUDAStream());
-
-  return output;
-}
-
-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank) {
-  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
-  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
-              "expected the last dim of the output_grads tensor equals or is "
-              "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
-
-  // output_grads sizes: (t, h, d)
-  // t: cumulative sum of sequence lengths
-  // h: head num
-  // d: dim of each head
-  const int t = output_grads.size(0);
-  const int h = output_grads.size(1);
-  const int d = output_grads.size(2);
-  // output_grads strides
-  const int stride_t = output_grads.stride(0);
-  const int stride_h = output_grads.stride(1);
-  const int stride_d = output_grads.stride(2);
-  // batch size
-  const int b = cu_seqlens.size(0) - 1;
-  // freqs' shape is (max_s, 1, 1, d2)
-  const int max_s = freqs.size(0);
-  const int d2 = freqs.size(3);
-
-  auto act_options = output_grads.options().requires_grad(false);
-  auto input_grads = torch::empty({t, h, d}, act_options);
-  const int o_stride_t = input_grads.stride(0);
-  const int o_stride_h = input_grads.stride(1);
-  const int o_stride_d = input_grads.stride(2);
-
-  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-  auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
-
-  nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                               input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
-                               stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                               at::cuda::getCurrentCUDAStream());
-
-  return input_grads;
-}
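
The reworked entry points fold the old thd variants into a single qkv_format switch and add context parallelism (cp_size, cp_rank), so each CP rank reads its own slice of the shared freqs table (hence the s * cp_size <= freqs.size(0) check). A PyTorch sketch of non-interleaved RoPE that indexes freqs by a hypothetical per-rank offset (the offset formula cp_rank * s is an illustrative assumption; the fused kernel's exact slicing may differ, e.g. for load-balanced causal CP):

import torch

def rope_reference(x, freqs, cp_size=1, cp_rank=0):
    """x: (s, b, h, d) slice held by this CP rank; freqs: (s * cp_size, 1, 1, d2)."""
    s, _, _, d = x.shape
    d2 = freqs.shape[-1]
    assert s * cp_size <= freqs.shape[0] and d >= d2
    f = freqs[cp_rank * s : (cp_rank + 1) * s]   # this rank's slice of the table
    cos, sin = f.cos(), f.sin()
    x_rot, x_pass = x[..., :d2], x[..., d2:]     # only the first d2 dims are rotated
    h1, h2 = x_rot.chunk(2, dim=-1)
    rotated = torch.cat((-h2, h1), dim=-1)       # "rotate half"
    return torch.cat((x_rot * cos + rotated * sin, x_pass), dim=-1)

x = torch.randn(4, 2, 3, 8)
freqs = torch.randn(8, 1, 1, 8)  # table covering cp_size=2 ranks of length-4 slices
out = rope_reference(x, freqs, cp_size=2, cp_rank=1)
assert out.shape == x.shape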
transformer_engine/pytorch/csrc/extensions/attention.cu

@@ -3,9 +3,11 @@
  *
  * See LICENSE for license information.
  ************************************************************************/
 #include "extensions.h"
+#include "kv_cache.cuh"
 #include "thd_utils.cuh"
+#include "transformer_engine/transformer_engine.h"

 constexpr int block_size = 512;
 constexpr int ctas_per_sm = 4;

@@ -95,11 +97,11 @@ std::vector<py::object> fused_attn_fwd(
     NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
     const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
     const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
-    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
-    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded,
+    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
+    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
+    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);
 #else

@@ -289,8 +291,8 @@ std::vector<py::object> fused_attn_bwd(
     const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
     const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
     const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
-    const c10::optional<at::Tensor> cu_seqlens_q_padded,
-    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
+    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
     py::handle dp_quantizer, py::handle dqkv_quantizer) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);

@@ -461,13 +463,13 @@ std::vector<py::object> fused_attn_bwd(
   nvte_tensor_pack_create(&nvte_aux_tensor_pack);
   nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
   for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
-    std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
-    auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
-    const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
+    const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec();
+    const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end());
     NVTEBasicTensor temp_data = {Aux_CTX_Tensors[i].data_ptr(),
                                  static_cast<NVTEDType>(
                                      GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
-                                 temp_shape};
+                                 nvte_make_shape(tmp.data(), tmp.size())};
     nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
   }
transformer_engine/pytorch/csrc/extensions/cast.cpp

@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   if (te_output.numel() == 0) return out;

+  QuantizationConfigWrapper quant_config;
+  quant_config.set_noop_tensor(te_noop.data());
+
   if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());

@@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
       allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
       process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
     }
-    QuantizationConfigWrapper quant_config;
+    // this config is used for cs scaling factor computation
+    // because compute scale is cannot be fused with quantize kernel
+    // so in nvte_quantize_v2 with current scaling, the quant config is not used again
     quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
     nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
+    quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
+    quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
   }
-  nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
-                     at::cuda::getCurrentCUDAStream());
+  nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
+                   at::cuda::getCurrentCUDAStream());
   return out;
 }
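
Both quantizer branches feed force_pow_2_scales and amax_epsilon into the scale computation that nvte_compute_scale_from_amax performs on device. A scalar Python sketch of that rule, inferred from the option comments in common.h above (the exact kernel arithmetic may differ; 448.0 is the FP8-E4M3 maximum):

import math

def compute_scale_from_amax(amax, max_fp8=448.0, force_pow_2_scales=False, epsilon=0.0):
    amax = max(amax, epsilon)   # "Amax ... has a floor of epsilon."
    scale = max_fp8 / amax      # map the observed range onto the fp8 range
    if force_pow_2_scales:
        scale = 2.0 ** math.floor(math.log2(scale))  # "rounded down to powers of 2"
    return scale

assert compute_scale_from_amax(112.0) == 4.0
assert compute_scale_from_amax(100.0, force_pow_2_scales=True) == 4.0  # 4.48 -> 4.0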
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp

@@ -157,15 +157,15 @@ void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool
   char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
   if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() * _tp_size > _ubuf.numel())
       NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
     ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
   } else {
-    if (input_tensor.numel() > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() > _ubuf.numel())
       NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
   }

@@ -189,7 +189,7 @@ py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
   std::vector<int64_t> torch_shape;
   if (shape.has_value()) {
     torch_shape = shape.value();
-    auto requested = product(torch_shape);
+    size_t requested = product(torch_shape);
     auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
     NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
                ") does not match allocated buffer size (", expected, ")!");

@@ -253,18 +253,18 @@ void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bo
   at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   if (local_chunk) {
     // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() * _tp_size > _ubuf.numel())
       NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
     NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
                                     input_tensor.numel() * input_tensor.element_size(),
                                     cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
   } else {
-    if (input_tensor.numel() > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() > _ubuf.numel())
       NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
       NVTE_ERROR("input data type does not match communication buffer!");
     NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
                                     input_tensor.numel() * input_tensor.element_size(),

@@ -280,7 +280,7 @@ py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
   std::vector<int64_t> torch_shape;
   if (shape.has_value()) {
     torch_shape = shape.value();
-    auto requested = product(torch_shape);
+    size_t requested = product(torch_shape);
     auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
     NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
                ") does not match allocated buffer size (", expected, ")!");
transformer_engine/pytorch/csrc/extensions/gemm.cpp

@@ -21,6 +21,7 @@
 #include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
+#include "util.h"

 namespace {

@@ -179,8 +180,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   const int sm_count = transformer_engine::cuda::sm_count(device_id);
   int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

+  // Keep the swizzled scaling factor tensors alive during the GEMM.
+  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
+
   auto main_stream = at::cuda::getCurrentCUDAStream();
   if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
+    // Optionally swizzle the scaling factors
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
+    swizzled_scale_inverses_list.emplace_back(
+        std::move(swizzle_scaling_factors(B_tensor, !transb)));
+
     if (comm_overlap) {
       // Prepare extra output tensor
       TensorWrapper extra_output_tensor;

@@ -317,17 +325,18 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       te_pre_gelu_out_vector, te_workspace_vector;
   std::vector<TensorWrapper> wrappers;
   std::vector<at::Tensor> D_vectors;
+  // Keep the swizzled scaling factor tensors alive during the GEMMs.
+  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;

   auto none = py::none();
   std::vector<size_t> single_output_begins;
   std::vector<size_t> single_output_ends;
   int slicing_dim;

   if (single_output && D == std::nullopt) {
     NVTE_ERROR("not implemented, D should be allocated for single output case.");
   }
-  void *output_data_ptr;
+  void *output_data_ptr = nullptr;
   if (single_output) {
     output_data_ptr = (*D)[0].data_ptr();
   }

@@ -384,6 +393,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       continue;
     }

+    // Optionally swizzle the scaling factors
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
+    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
+
     auto te_D = makeTransformerEngineTensor(out_tensor);
     auto te_bias = makeTransformerEngineTensor(bias[i]);
     auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu

@@ -12,6 +12,8 @@
 // #include <torch/all.h>
 #include <assert.h>

+#include <limits>
+
 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <sstream>

@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor {
     n -= chunk_idx * chunk_size;
     for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
-      float scale_val = transformer_engine::compute_scale_from_amax(
-          amax[i_start], max_fp8, force_pow_2_scales, epsilon);
+      float scale_val =
+          transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8, force_pow_2_scales,
+                                                      epsilon, std::numeric_limits<float>::max());
       scale[i_start] = scale_val;
       transformer_engine::reciprocal(scale_inv + i_start, scale_val);
     }
transformer_engine/pytorch/csrc/extensions/normalization.cpp
View file @
ab3e5a92
...
...
@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if
(
force_unfused_kernel
)
{
QuantizationConfigWrapper
quant_config
;
if
(
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
()))
{
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto
my_quantizer_cs
=
static_cast
<
Float8CurrentScalingQuantizer
*>
(
my_quantizer
.
get
());
...
...
@@ -166,15 +167,18 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
allreduce_opts
.
reduceOp
=
c10d
::
ReduceOp
::
MAX
;
process_group_ptr
->
allreduce
(
tensors
,
allreduce_opts
)
->
wait
();
}
QuantizationConfigWrapper
quant_config
;
quant_config
.
set_force_pow_2_scales
(
my_quantizer_cs
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_cs
->
amax_epsilon
);
nvte_compute_scale_from_amax
(
out_cu
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu
.
set_amax
(
nullptr
,
DType
::
kFloat32
,
out_cu
.
defaultShape
);
}
else
if
(
IsFloat8BlockwiseQuantizers
(
quantizer
.
ptr
()))
{
auto
my_quantizer_bw
=
static_cast
<
Float8BlockQuantizer
*>
(
my_quantizer
.
get
());
quant_config
.
set_force_pow_2_scales
(
my_quantizer_bw
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_bw
->
amax_epsilon
);
}
nvte_quantize_
noop
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
nullptr
,
at
::
cuda
::
getCurrentCUDAStream
());
nvte_quantize_
v2
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
}
return
{
out
,
py
::
cast
(
mu
),
py
::
cast
(
rsigma
)};
...
@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    QuantizationConfigWrapper quant_config;
     if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
       // my_quantizer here has to be a Float8CurrentScalingQuantizer
       auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
...
@@ -309,15 +314,18 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
         allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
         process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
       }
-      QuantizationConfigWrapper quant_config;
       quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
       quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
       nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
       // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
       out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+      auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
+      quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
     }
-    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
-                       at::cuda::getCurrentCUDAStream());
+    nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
+                     at::cuda::getCurrentCUDAStream());
   }
   return {out, py::none(), py::cast(rsigma)};
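The rmsnorm_fwd hunks mirror the layernorm_fwd changes above; the only difference is that RMSNorm computes no mean, so the function returns py::none() in place of py::cast(mu).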
...
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
0 → 100644
View file @ ab3e5a92
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../extensions.h"

#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif

#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>

namespace nvshmem_api {

void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
  nvshmemx_init_attr_t attr = {};
  nvshmemx_uniqueid_t id = {};

  int my_rank = process_group->getRank();
  int num_ranks = process_group->getSize();
  if (my_rank == 0) {
    nvshmemx_get_uniqueid(&id);
  }

  auto backend_is_nccl =
      (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
  NVTE_CHECK(backend_is_nccl, "Currently only NCCL bootstrap is supported for NVSHMEM");

  // Broadcast the unique ID from rank 0 through the existing process group;
  // with an NCCL backend the payload takes a round trip through GPU memory.
  auto datatensor =
      torch::from_blob(reinterpret_cast<void *>(&id),
                       {static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
                       at::device(torch::kCPU).dtype(torch::kUInt8));
  auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;

  c10d::BroadcastOptions bcast_opts;
  bcast_opts.rootRank = 0;
  std::vector<torch::Tensor> datachunk = {datatmp};
  auto work = process_group->broadcast(datachunk, bcast_opts);
  work->wait();

  if (backend_is_nccl) {
    datatensor.copy_(datatmp.cpu());
    datatmp = torch::Tensor();
  }

  nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

  NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
             " != nvshmem_my_pe(): ", nvshmem_my_pe());
  NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
             " != nvshmem_n_pes(): ", nvshmem_n_pes());
#else
  NVTE_ERROR("Internal TE error: init_nvshmem_backend requires TE to be compiled ",
             "with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
  cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();

  WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;
  if (wait_kind == "kernel") {
    wait_kind_enum = WaitKind::KERNEL_WAIT;
  } else if (wait_kind == "nvshmem") {
    wait_kind_enum = WaitKind::NVSHMEM_WAIT;
  } else if (wait_kind == "stream") {
    wait_kind_enum = WaitKind::STREAM_WAIT;
  } else {
    NVTE_ERROR("Invalid wait kind: ", wait_kind);
  }
  nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);
#else
  NVTE_ERROR("Internal TE error: nvshmem_wait_on_current_stream requires TE to be compiled ",
             "with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
  auto option_gpu =
      at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
  auto size = torch::elementSize(dtype) *
              std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
  // from_blob with a custom deleter, so the tensor's lifetime controls the
  // symmetric-heap allocation.
  return at::from_blob(
      nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
#else
  NVTE_ERROR("Internal TE error: create_nvshmem_tensor requires TE to be compiled ",
             "with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
                                    torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
  void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
  void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
  auto nelement = src.numel() * src.element_size();
  uint64_t sigval = 1;
  at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();

  // Fused put + signal: the payload lands in `dst` on `peer`, then the peer's
  // signal word is set to 1, all ordered on the current stream.
  nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval,
                                   NVSHMEM_SIGNAL_SET, peer, (cudaStream_t)cur_stream);
#else
  NVTE_ERROR("Internal TE error: nvshmem_send_on_current_stream requires TE to be compiled ",
             "with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
  // Qualified call so we hit the NVSHMEM library function rather than
  // recursing into this wrapper of the same name.
  ::nvshmem_finalize();
#else
  NVTE_ERROR("Internal TE error: nvshmem_finalize requires TE to be compiled ",
             "with NVTE_ENABLE_NVSHMEM=1!");
#endif
}

}  // namespace nvshmem_api
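Taken together, nvshmem_send_on_current_stream and nvshmem_wait_on_current_stream form a one-sided producer/consumer handshake: the sender writes the payload and sets the receiver's signal word in one fused operation, and the receiver blocks its stream until the signal arrives. A minimal standalone sketch of the same pairing written directly against the NVSHMEM host API rather than TE's wrappers (assumes an already-initialized NVSHMEM job; sizes are illustrative):

#include <cstdint>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>

// One-sided handshake between PE 0 (producer) and PE 1 (consumer).
void ping(cudaStream_t stream) {
  const size_t nbytes = 1024;
  // Symmetric allocations: the same address is valid on every PE.
  void *buf = nvshmem_malloc(nbytes);
  uint64_t *sig = static_cast<uint64_t *>(nvshmem_calloc(1, sizeof(uint64_t)));

  if (nvshmem_my_pe() == 0) {
    // Producer: write the payload into PE 1's buffer, then set PE 1's signal
    // word to 1, both ordered on the given stream.
    nvshmemx_putmem_signal_on_stream(buf, buf, nbytes, sig, /*signal=*/1,
                                     NVSHMEM_SIGNAL_SET, /*peer=*/1, stream);
  } else if (nvshmem_my_pe() == 1) {
    // Consumer: block the stream until the local signal word becomes 1,
    // after which the payload is guaranteed to be visible.
    nvshmemx_signal_wait_until_on_stream(sig, NVSHMEM_CMP_EQ, 1, stream);
  }
  cudaStreamSynchronize(stream);
  nvshmem_free(sig);
  nvshmem_free(buf);
}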
transformer_engine/pytorch/csrc/extensions/padding.cpp
View file @ ab3e5a92
...
@@ -17,7 +17,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
   NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
-  const int num_tensors = input_row_list.size();
+  const auto num_tensors = input_row_list.size();
   // Extract properties from PyTorch tensors
   std::vector<void *> input_dptr_list, output_dptr_list;
   std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
...
transformer_engine/pytorch/csrc/extensions/permutation.cu
View file @ ab3e5a92
...
@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
       sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, num_tokens * topK);
-  // Activations type
-  at::ScalarType _st;
-  if (dtype == transformer_engine::DType::kFloat8E4M3 ||
-      dtype == transformer_engine::DType::kFloat8E5M2)
-    _st = at::ScalarType::Byte;
-  else
-    _st = input.scalar_type();
   // Output buffer alloc
   num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
-  at::Tensor permuted_output = torch::empty(
-      {num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
+  at::Tensor permuted_output = torch::empty(
+      {num_out_tokens, num_cols},
+      torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
   at::Tensor row_id_map = torch::empty(
       {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
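The deleted block mapped FP8 dtypes onto Byte storage before allocating outputs, presumably because ATen lacked native float8 scalar types; the replacement trusts input.scalar_type() directly, which works once FP8 payloads already arrive with the storage dtype the caller intends. A small sketch contrasting the two idioms (illustrative helpers, not part of the diff):

#include <torch/extension.h>

// Retired idiom: FP8 payloads were allocated as raw bytes.
at::ScalarType storage_dtype_legacy(const at::Tensor &input, bool is_fp8) {
  return is_fp8 ? at::ScalarType::Byte : input.scalar_type();
}

// Replacement idiom: reuse the input's own scalar type for the output buffer.
at::Tensor alloc_like_cols(const at::Tensor &input, int64_t rows, int64_t cols) {
  return torch::empty({rows, cols},
                      torch::dtype(input.scalar_type()).device(torch::kCUDA));
}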
...
@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
   using namespace transformer_engine::pytorch;
   int num_cols = input.size(1);
-  // Activations type
-  at::ScalarType _st;
-  if (dtype == transformer_engine::DType::kFloat8E4M3 ||
-      dtype == transformer_engine::DType::kFloat8E5M2)
-    _st = at::ScalarType::Byte;
-  else
-    _st = input.scalar_type();
   // Output buffer alloc
-  at::Tensor unpermuted_output = torch::empty(
-      {num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
+  at::Tensor unpermuted_output = torch::empty(
+      {num_tokens, num_cols},
+      torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
   auto stream = at::cuda::getCurrentCUDAStream().stream();
...
@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
   const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
   int num_cols = input_bwd.size(1);
-  // Activations type
-  at::ScalarType _st;
-  if (dtype == transformer_engine::DType::kFloat8E4M3 ||
-      dtype == transformer_engine::DType::kFloat8E5M2)
-    _st = at::ScalarType::Byte;
-  else
-    _st = input_bwd.scalar_type();
   // Output buffer alloc
-  at::Tensor act_grad = torch::empty(
-      {input_fwd.size(0), num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
+  at::Tensor act_grad = torch::empty(
+      {input_fwd.size(0), num_cols},
+      torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
   at::Tensor prob_grad = torch::empty(
       {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
...
transformer_engine/pytorch/csrc/extensions/pybind.cpp
View file @ ab3e5a92
...
@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
 PyTypeObject *MXFP8TensorPythonClass = nullptr;  /// TODO Remove
 PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
 PyTypeObject *MXFP8QuantizerClass = nullptr;
+PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
+PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
+PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;

 void init_float8_extension() {
   if (Float8TensorPythonClass) return;
...
@@ -61,9 +64,31 @@ void init_mxfp8_extension() {
              "Internal error: could not initialize pyTorch MXFP8 extension.");
 }

+void init_float8blockwise_extension() {
+  if (Float8BlockwiseQTensorBasePythonClass) return;
+  auto fp8_module =
+      py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
+  auto fp8_base_module =
+      py::module_::import("transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
+  Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
+  Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
+  Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
+  NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
+             "Internal error: could not initialize pyTorch float8blockwise extension.");
+  NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
+             "Internal error: could not initialize pyTorch float8blockwise extension.");
+  NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
+             "Internal error: could not initialize pyTorch float8blockwise extension.");
+}
+
 void init_extension() {
   init_float8_extension();
   init_mxfp8_extension();
+  init_float8blockwise_extension();
 }

}  // namespace transformer_engine::pytorch
...
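The init_* functions above cache Python type objects in process-lifetime globals. PyObject_GetAttrString returns a new (owned) reference, so stashing the pointer without ever DECREF'ing it keeps the class alive for the life of the process, and the early-return guard makes the lookup idempotent. The idiom in isolation (hypothetical module/attribute names):

#include <Python.h>

static PyTypeObject *cached_cls = nullptr;

void cache_type_once(PyObject *module) {
  if (cached_cls) return;  // idempotent, like the init_* guards above
  // New reference, intentionally held for the lifetime of the process.
  cached_cls = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(module, "SomeClass"));
}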
@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("output") = py::none(), py::arg("noop") = py::none());
   m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
         py::arg("otype"));
+  m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
+        "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
   m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
...
@@ -170,15 +196,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
arg
(
"ln_out"
),
py
::
arg
(
"quantizer"
),
py
::
arg
(
"otype"
),
py
::
arg
(
"sm_margin"
),
py
::
arg
(
"zero_centered_gamma"
));
m
.
def
(
"rmsnorm_bwd"
,
&
rmsnorm_bwd
,
"Backward of RMSNorm"
);
m
.
def
(
"fused_multi_quantize"
,
&
fused_multi_quantize
,
"Fused Multi-tensor Cast + Transpose"
,
py
::
arg
(
"input_list"
),
py
::
arg
(
"output_list"
),
py
::
arg
(
"quantizer_list"
),
py
::
arg
(
"otype"
));
m
.
def
(
"fused_multi_quantize"
,
&
transformer_engine
::
pytorch
::
fused_multi_quantize
,
"Fused Multi-tensor Cast + Transpose"
,
py
::
arg
(
"input_list"
),
py
::
arg
(
"output_list"
),
py
::
arg
(
"quantizer_list"
),
py
::
arg
(
"otype"
));
m
.
def
(
"te_general_grouped_gemm"
,
&
te_general_grouped_gemm
,
"Grouped GEMM"
);
#ifdef USE_ROCM
m
.
def
(
"te_batchgemm_ts"
,
&
te_batchgemm_ts
,
"Batched GEMM"
);
/// rocblas
#endif
m
.
def
(
"fp8_transpose"
,
&
fp8_transpose
,
"Transpose with FP8 I/O"
,
py
::
arg
(
"input"
),
py
::
arg
(
"dtype"
),
py
::
kw_only
(),
py
::
arg
(
"out"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"fp8_transpose"
,
&
transformer_engine
::
pytorch
::
fp8_transpose
,
"Transpose with FP8 I/O"
,
py
::
arg
(
"input"
),
py
::
arg
(
"dtype"
),
py
::
kw_only
(),
py
::
arg
(
"out"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"get_fused_attn_backend"
,
&
get_fused_attn_backend
,
"Get Fused Attention backend"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"compute_amax"
,
&
compute_amax
,
"Compute amax"
,
py
::
arg
(
"input"
),
py
::
arg
(
"amax"
));
...
...
@@ -206,10 +234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"fused_rope_backward"
,
&
fused_rope_backward
,
"Fused Apply RoPE BWD"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"fused_rope_thd_forward"
,
&
fused_rope_thd_forward
,
"Fused Apply RoPE FWD for thd format"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"fused_rope_thd_backward"
,
&
fused_rope_thd_backward
,
"Fused Apply RoPE BWD for thd format"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
// Misc
m
.
def
(
"get_cublasLt_version"
,
&
get_cublasLt_version
,
"Get cublasLt version"
,
...
...
@@ -240,6 +264,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Generate partitioned indices for inputs in THD format"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
// nvshmem functions
m
.
def
(
"init_nvshmem_backend"
,
&
nvshmem_api
::
init_nvshmem_backend
,
"Initialize nvshmem backend with Pytorch distributed process groups"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"create_nvshmem_tensor"
,
&
nvshmem_api
::
create_nvshmem_tensor
,
"Create a tensor in NVSHMEM shared memory"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"nvshmem_send_on_current_stream"
,
&
nvshmem_api
::
nvshmem_send_on_current_stream
,
"Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"nvshmem_wait_on_current_stream"
,
&
nvshmem_api
::
nvshmem_wait_on_current_stream
,
"Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
"stream"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
m
.
def
(
"nvshmem_finalize"
,
&
nvshmem_api
::
nvshmem_finalize
,
"Clean up and finalize the NVSHMEM communication backend and free associated resources"
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
// multi-tensor functions
m
.
def
(
"multi_tensor_scale"
,
&
multi_tensor_scale_cuda
,
"Fused overflow check + scale for a list of contiguous tensors"
,
...
...