OpenDAS / TransformerEngine / Commits / 87e3e56e

Commit 87e3e56e, authored Aug 27, 2025 by yuguo

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine

Parents: 2f11bd2e, 734bcedd
Changes: 217
Showing 20 changed files with 660 additions and 390 deletions (+660, -390)
Changed files (20):

  transformer_engine/jax/sharding.py  (+25, -34)
  transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py  (+1, -1)
  transformer_engine/pytorch/attention/dot_product_attention/utils.py  (+3, -3)
  transformer_engine/pytorch/attention/multi_head_attention.py  (+116, -23)
  transformer_engine/pytorch/cpp_extensions/gemm.py  (+16, -0)
  transformer_engine/pytorch/cpu_offload.py  (+15, -3)
  transformer_engine/pytorch/cross_entropy.py  (+13, -2)
  transformer_engine/pytorch/csrc/common.cpp  (+2, -2)
  transformer_engine/pytorch/csrc/common.h  (+78, -20)
  transformer_engine/pytorch/csrc/extensions.h  (+17, -3)
  transformer_engine/pytorch/csrc/extensions/activation.cpp  (+74, -69)
  transformer_engine/pytorch/csrc/extensions/attention.cpp  (+45, -8)
  transformer_engine/pytorch/csrc/extensions/bias.cpp  (+198, -55)
  transformer_engine/pytorch/csrc/extensions/cast.cpp  (+6, -130)
  transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp  (+14, -4)
  transformer_engine/pytorch/csrc/extensions/gemm.cpp  (+28, -13)
  transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp  (+5, -11)
  transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp  (+1, -2)
  transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp  (+2, -5)
  transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp  (+1, -2)
transformer_engine/jax/sharding.py  (view file @ 87e3e56e)

@@ -15,10 +15,10 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, Optional
 import warnings
-from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp
-from jax.sharding import PartitionSpec
+from jax.interpreters import pxla
+from jax.sharding import PartitionSpec, get_abstract_mesh
 import numpy as np

 _PXLA_THREAD_RESOURCES = pxla.thread_resources

@@ -86,24 +86,29 @@ def get_sharding_map_logic_axis_to_mesh_axis():
     return te_logical_axis_to_mesh_axis


-def generate_pspec(logical_axis_names):
+def _generate_pspec(logical_axis_names):
     """
-    Convert logical axes to PartitionSpec
+    Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec.
+    Note, this method does not support Flax logical axes.
+
+    Args:
+        logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec.
+
+    Returns:
+        A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine
+        logical axis names
     """
     rules = get_sharding_map_logic_axis_to_mesh_axis()
-    # mesh_axis_names = [rules[name] for name in logical_axis_names]
-    mesh_axis_names = []
-    for name in logical_axis_names:
-        axis_name = rules[name] if name in rules else None
-        mesh_axis_names.append(axis_name)
+    mesh_axis_names = [rules.get(name) for name in logical_axis_names]
     pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
     return pspec


 def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
     """
-    A wrapper function to jax.lax.with_sharding_constraint to
-    support the case that Mesh is empty.
+    A wrapper function to jax.lax.with_sharding_constraint
+    1. Does nothing if mesh is empty.
+    2. If all mesh axes are manual axes, replaces pspec with all Nones.
+    3. Otherwise, strips only the manual axes.
     """
     if pspec is None:
         return x

@@ -111,7 +116,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
     mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
     if mesh.empty:
         return x
-    return jax.lax.with_sharding_constraint(x, pspec)
+
+    # We want to exclude the axes that already used by shard_map and shard_map
+    # only sets those in the abstract_mesh, not the physical one
+    manual_axis_names = get_abstract_mesh().manual_axes
+    cleaned_axis_names = tuple(
+        name if name not in manual_axis_names else None for name in pspec
+    )
+    cleaned_pspec = PartitionSpec(*cleaned_axis_names)
+    return jax.lax.with_sharding_constraint(x, cleaned_pspec)


 def with_sharding_constraint_by_logical_axes(

@@ -159,7 +171,7 @@ def with_sharding_constraint_by_logical_axes(
     # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
     assert len(x.shape) == len(logical_axis_names)
-    pspec = generate_pspec(logical_axis_names)
+    pspec = _generate_pspec(logical_axis_names)
     return with_sharding_constraint(x, pspec)

@@ -383,24 +395,3 @@ class ShardingType(Enum):
     TP_ROW = (MajorShardingType.TP, "tp_row")
     DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
-
-
-def get_non_contracting_logical_axes(
-    ndim, logical_axes: tuple[Optional[str]], contracting_dims
-) -> tuple[Optional[str]]:
-    """Get logical axes for non-contracting dimensions.
-
-    Args:
-        ndim: Number of dimensions in the tensor.
-        logical_axes: Tuple of logical axes for each dimension.
-        contracting_dims: Set of dimensions that are being contracted.
-
-    Returns:
-        Tuple of logical axes for non-contracting dimensions.
-    """
-    assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
-    assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."
-    non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
-    non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
-    return non_contracting_logical_axes
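For orientation, a minimal standalone sketch of the lookup that `_generate_pspec` performs: map each TE logical axis name to a mesh axis and leave unknown axes unsharded. The `rules` table and axis names below are illustrative only, not TE's actual mapping.

from jax.sharding import PartitionSpec

# Hypothetical logical-axis rule table; TE builds the real one from
# get_sharding_map_logic_axis_to_mesh_axis().
rules = {"nvte_batch": "dp", "nvte_hidden_tp": "tp"}

def generate_pspec_sketch(logical_axis_names):
    # rules.get(name) returns None for unknown axes, i.e. replicated dimensions.
    return PartitionSpec(*(rules.get(name) for name in logical_axis_names))

print(generate_pspec_sketch(("nvte_batch", None, "nvte_hidden_tp")))
# PartitionSpec('dp', None, 'tp')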
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py  (view file @ 87e3e56e)

@@ -630,7 +630,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             If true, there are padding tokens between individual sequences in a packed batch.
         """

-        with self.prepare_forward(
+        with torch.cuda.device(query_layer.device), self.prepare_forward(
             query_layer,
             num_gemms=3,
             allow_non_contiguous=True,
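The only change here wraps `prepare_forward` in a `torch.cuda.device` context. A small illustration of what that context manager buys (the helper below is illustrative, not part of the commit):

import torch

def run_on_inputs_device(query_layer: torch.Tensor) -> torch.Tensor:
    # Inside the context, the input's GPU becomes the current device, so any
    # allocation or kernel launch that defaults to "cuda" lands on the same
    # device as query_layer, even on multi-GPU nodes where a different device
    # happens to be current.
    with torch.cuda.device(query_layer.device):
        workspace = torch.empty(1024, device="cuda")
    return workspace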
transformer_engine/pytorch/attention/dot_product_attention/utils.py  (view file @ 87e3e56e)

@@ -438,8 +438,8 @@ def get_attention_backend(
     #         | FP8            | non-paged/paged | sm90  | thd           | >= 1
     # Unfused | FP32/FP16/BF16 | non-paged/paged | all   | bshd,sbhd,thd | >= 1
     if inference_params is not None:
-        if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0):
-            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12")
+        if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
+            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
             use_fused_attention = False
         if context_parallel:
             logger.debug("Disabling all backends for KV caching with context parallelism")

@@ -625,7 +625,7 @@ def get_attention_backend(
                 " bias for THD format"
             )
             use_fused_attention = False
-        elif fp8 and head_dim_qk != head_dim_v:
+        elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v:
             logger.debug(
                 "Disabling FusedAttention as it does not support context parallelism with FP8"
                 " MLA attention"
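The sm89 gate changed from a strict to an inclusive version comparison. A small sketch of the tuple comparison involved (the function name is illustrative, not from the source):

def fused_attention_ok_for_kv_cache(device_compute_capability, cudnn_version):
    # Version tuples compare lexicographically, so (9, 12, 0) <= (9, 12, 0) is True:
    # with the patched check, cuDNN 9.12.0 itself is now excluded on sm89.
    if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
        return False
    return True

print(fused_attention_ok_for_kv_cache((8, 9), (9, 12, 0)))  # False (was True with "<")
print(fused_attention_ok_for_kv_cache((8, 9), (9, 13, 0)))  # True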
transformer_engine/pytorch/attention/multi_head_attention.py  (view file @ 87e3e56e)

@@ -11,7 +11,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
-from transformer_engine.pytorch.module import LayerNormLinear, Linear
+from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
 from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
 from transformer_engine.pytorch.utils import (
     SplitAlongDim,

@@ -175,14 +175,23 @@ class MultiheadAttention(torch.nn.Module):
                  parameter for query-key-value. This enables optimizations such as QKV
                  fusion without concatentations/splits and also enables the argument
                  `fuse_wgrad_accumulation`.
-    use_qk_norm: bool, default = 'False'
-                if set to `True`, L2 normalization is applied to query and key tensors
-                after RoPE (if applicable) but before attention computation.
-                This follows the Llama4 approach for QK normalization to improve
-                training stability and model performance.
+    qk_norm_type: Optional[str], default = None
+                type of normalization to apply to query and key tensors.
+                Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no
+                normalization is applied.
+                When 'L2Normalization', L2 normalization is applied to query and key tensors.
+                When 'RMSNorm', RMS normalization is applied to query and key tensors.
+                When 'LayerNorm', layer normalization is applied to query and key tensors.
+                Normalization is applied after RoPE (if applicable) but before attention
+                computation when `qk_norm_before_rope` is False. This follows the e.g. Llama4
+                approach for QK normalization to improve training stability and model performance.
     qk_norm_eps: float, default = 1e-6
-                epsilon value for L2 normalization of query and key tensors.
-                Only used when `use_qk_norm` is True.
+                epsilon value for normalization of query and key tensors.
+                Only used when `qk_norm_type` is not None.
+    qk_norm_before_rope: bool, default = `False`
+                if set to `True`, query and key normalization is applied before rotary position
+                embedding. When `False` (default), normalization is applied after RoPE.
+                This parameter allows supporting different architectural variants that apply
+                QK normalization at different points.
     seq_length: Optional[int], default = `None`
                sequence length of input samples. Needed for JIT Warmup, a technique where jit
                fused functions are warmed up before training to ensure same kernels are used for

@@ -231,8 +240,9 @@ class MultiheadAttention(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         qkv_format: str = "sbhd",
         name: str = None,
-        use_qk_norm: bool = False,
+        qk_norm_type: Optional[str] = None,
         qk_norm_eps: float = 1e-6,
+        qk_norm_before_rope: bool = False,
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
     ) -> None:

@@ -264,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
             qkv_weight_interleaved = False
         self.qkv_weight_interleaved = qkv_weight_interleaved
         self.rotary_pos_interleaved = rotary_pos_interleaved
+        self.qk_norm_before_rope = qk_norm_before_rope

         assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
         if layer_number is not None:

@@ -288,7 +299,6 @@ class MultiheadAttention(torch.nn.Module):
         self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
         self.name = name
-        self.use_qk_norm = use_qk_norm

         common_gemm_kwargs = {
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,

@@ -300,13 +310,9 @@ class MultiheadAttention(torch.nn.Module):
             "device": device,
         }

-        # Initialize L2 normalization modules for query and key if enabled
-        if self.use_qk_norm:
-            self.qk_norm = L2Normalization(
-                eps=qk_norm_eps,
-                seq_length=seq_length,
-                micro_batch_size=micro_batch_size,
-            )
+        self.q_norm, self.k_norm = self._create_qk_norm_modules(
+            qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size
+        )

         qkv_parallel_mode = "column" if set_parallel_mode else None

@@ -427,6 +433,78 @@ class MultiheadAttention(torch.nn.Module):
                 **common_gemm_kwargs,
             )

+    def _create_qk_norm_modules(
+        self,
+        qk_norm_type: Optional[str],
+        qk_norm_eps: float,
+        device: Union[torch.device, str],
+        seq_length: Optional[int] = None,
+        micro_batch_size: Optional[int] = None,
+    ) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
+        """
+        Create query and key normalization modules based on the specified normalization type.
+
+        Parameters
+        ----------
+        qk_norm_type : Optional[str]
+            Type of normalization to apply. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'
+        qk_norm_eps : float
+            Epsilon value for numerical stability
+        device : Union[torch.device, str]
+            Device to place the normalization modules on
+        seq_length : Optional[int], default = None
+            Sequence length for L2Normalization optimization
+        micro_batch_size : Optional[int], default = None
+            Micro batch size for L2Normalization optimization
+
+        Returns
+        -------
+        Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]
+            Query and key normalization modules (q_norm, k_norm)
+        """
+        if qk_norm_type is None:
+            return None, None
+        if qk_norm_type == "L2Normalization":
+            l2_norm = L2Normalization(
+                eps=qk_norm_eps,
+                seq_length=seq_length,
+                micro_batch_size=micro_batch_size,
+            )
+            # L2Normalization is parameter-free, so we can share the same instance
+            return l2_norm, l2_norm
+        if qk_norm_type == "RMSNorm":
+            q_norm = RMSNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            k_norm = RMSNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            return q_norm, k_norm
+        if qk_norm_type == "LayerNorm":
+            q_norm = LayerNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            k_norm = LayerNorm(
+                normalized_shape=self.hidden_size_per_attention_head,
+                eps=qk_norm_eps,
+                device=device,
+            )
+            return q_norm, k_norm
+        raise ValueError(
+            f"Unsupported QK norm type: {qk_norm_type}. "
+            "Supported types: ['L2Normalization', 'RMSNorm', 'LayerNorm']"
+        )
+
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """
         Set the tensor parallel group for the given

@@ -789,6 +867,14 @@ class MultiheadAttention(torch.nn.Module):
             )
             query_layer = query_layer.view(*new_tensor_shape)

+        # ===========================
+        # Apply normalization to query and key tensors (before RoPE if configured)
+        # ===========================
+
+        if self.q_norm is not None and self.qk_norm_before_rope:
+            query_layer = self.q_norm(query_layer)
+            key_layer = self.k_norm(key_layer)
+
         # ======================================================
         # Apply relative positional encoding (rotary embedding)
         # ======================================================

@@ -821,12 +907,19 @@ class MultiheadAttention(torch.nn.Module):
                 q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                 k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

+            if pad_between_seqs:
+                rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
+                rotary_pos_cu_seq_lens_kv = cu_seqlens_kv_padded
+            else:
+                rotary_pos_cu_seq_lens_q = cu_seqlens_q
+                rotary_pos_cu_seq_lens_kv = cu_seqlens_kv
+
             query_layer = apply_rotary_pos_emb(
                 query_layer,
                 q_pos_emb,
                 self.qkv_format,
                 fused=True,
-                cu_seqlens=cu_seqlens_q,
+                cu_seqlens=rotary_pos_cu_seq_lens_q,
                 cp_size=self.cp_size,
                 cp_rank=self.cp_rank,
                 interleaved=self.rotary_pos_interleaved,

@@ -836,19 +929,19 @@ class MultiheadAttention(torch.nn.Module):
                 k_pos_emb,
                 self.qkv_format,
                 fused=True,
-                cu_seqlens=cu_seqlens_kv,
+                cu_seqlens=rotary_pos_cu_seq_lens_kv,
                 cp_size=self.cp_size,
                 cp_rank=self.cp_rank,
                 interleaved=self.rotary_pos_interleaved,
             )

         # ===========================
-        # Apply L2 normalization to query and key tensors
+        # Apply normalization to query and key tensors (after RoPE if not applied before)
         # ===========================

-        if self.use_qk_norm:
-            query_layer = self.qk_norm(query_layer)
-            key_layer = self.qk_norm(key_layer)
+        if self.q_norm is not None and not self.qk_norm_before_rope:
+            query_layer = self.q_norm(query_layer)
+            key_layer = self.k_norm(key_layer)

         # ===========================
         # Core attention computation
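A hedged, self-contained sketch of the dispatch pattern `_create_qk_norm_modules` introduces, using plain torch.nn stand-ins (torch.nn.RMSNorm requires a recent PyTorch) rather than TE's RMSNorm/LayerNorm/L2Normalization modules:

from typing import Optional, Tuple
import torch

def create_qk_norm_modules_sketch(
    qk_norm_type: Optional[str], head_dim: int, eps: float = 1e-6
) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
    # None means "no QK normalization"; otherwise pick a module per norm type.
    if qk_norm_type is None:
        return None, None
    if qk_norm_type == "RMSNorm":
        # Separate instances so query and key get independent learnable scales.
        return torch.nn.RMSNorm(head_dim, eps=eps), torch.nn.RMSNorm(head_dim, eps=eps)
    if qk_norm_type == "LayerNorm":
        return torch.nn.LayerNorm(head_dim, eps=eps), torch.nn.LayerNorm(head_dim, eps=eps)
    raise ValueError(f"Unsupported QK norm type: {qk_norm_type}")

q_norm, k_norm = create_qk_norm_modules_sketch("RMSNorm", head_dim=64)
q = torch.randn(2, 8, 64)   # toy [tokens, heads, head_dim] shape
print(q_norm(q).shape)      # torch.Size([2, 8, 64])

The real method additionally covers the parameter-free 'L2Normalization' option, where a single instance is shared between query and key.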
transformer_engine/pytorch/cpp_extensions/gemm.py  (view file @ 87e3e56e)

@@ -46,6 +46,15 @@ __all__ = [
 ]


+def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
+    """Validate whether a GEMM scaling factor is consistent with its usage"""
+    if required:
+        return scale if scale is not None else 1.0
+    if scale not in (0.0, None):
+        raise ValueError("scale must be zero")
+    return 0.0
+
+
 def general_gemm(
     A: torch.Tensor,
     B: torch.Tensor,

@@ -54,6 +63,8 @@ def general_gemm(
     quantization_params: Optional[Quantizer] = None,
     gelu: bool = False,
     gelu_in: torch.Tensor = None,
+    alpha: float = 1.0,
+    beta: Optional[float] = None,
     accumulate: bool = False,
     layout: str = "TN",
     out: Optional[torch.Tensor] = None,

@@ -72,6 +83,9 @@ def general_gemm(
     transb = layout[1] == "T"
     # assert quantization_params is None, "FP8 output not supported yet"

+    alpha = validate_gemm_scale(alpha, True)
+    beta = validate_gemm_scale(beta, accumulate)
+
     # if ub_type is not None:
     #     assert ub is not None, (
     #         f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"

@@ -349,6 +363,8 @@ def general_gemm(
         "comm_type": ub_type,
         "extra_output": extra_output,
         "bulk_overlap": bulk_overlap,
+        "alpha": alpha,
+        "beta": beta,
     }

     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
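The new `validate_gemm_scale` helper keeps alpha and beta consistent with `accumulate`. Its behavior, restated as a standalone snippet with example inputs (the example values are illustrative):

from typing import Optional

def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
    # Alpha is always required; beta is only meaningful when accumulate=True.
    if required:
        return scale if scale is not None else 1.0
    if scale not in (0.0, None):
        raise ValueError("scale must be zero")
    return 0.0

print(validate_gemm_scale(None, True))    # 1.0  (alpha defaults to 1)
print(validate_gemm_scale(0.5, True))     # 0.5
print(validate_gemm_scale(None, False))   # 0.0  (beta forced to 0 when not accumulating)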
transformer_engine/pytorch/cpu_offload.py  (view file @ 87e3e56e)

@@ -431,7 +431,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             tensor = self.fp8_tensor_object_map.pop(tensor_tag)
             if self.double_buffering:
-                tensor.do_not_clear = True
+                tensor._do_not_clear = True
             self.tensor_tag_to_buf.pop(tensor_tag, None)

             # the tensor should have been copied back in on_group_commit_backward()

@@ -556,21 +556,33 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         for tensor_label, state in self.tensor_tag_to_state.items():
             group_id, _ = tensor_label
             if group_id == group_to_reload:
+                if self.double_buffering:
+                    reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
+                else:
+                    reload_buffer = None
                 if isinstance(state, tuple):
                     recovered_tensor = SynchronizedGroupOffloadHandler.reload(
-                        state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx]
+                        state, True, reload_buffer
                     )
                     buffer_idx = buffer_idx + 1
                     self.tensor_tag_to_state[tensor_label] = recovered_tensor
                 elif isinstance(state, list):
                     tensor_list = []
                     for state_tuple in state:
+                        if self.double_buffering:
+                            reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
+                        else:
+                            reload_buffer = None
                         if isinstance(state_tuple, tuple):
                             tensor_list.append(
                                 SynchronizedGroupOffloadHandler.reload(
                                     state_tuple,
                                     True,
-                                    self.reload_double_buffer[double_buffer_idx][buffer_idx],
+                                    reload_buffer,
                                 )
                             )
                             buffer_idx = buffer_idx + 1
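A toy sketch of the patched reload logic: the pre-allocated double buffer is only indexed when double buffering is enabled, otherwise `reload` receives None and allocates its own destination. Names and data below are illustrative, not the handler's real state.

def pick_reload_buffer(double_buffering, reload_double_buffer, double_buffer_idx, buffer_idx):
    # Without this guard, reload_double_buffer would be indexed even when it was
    # never allocated (double buffering disabled), which is what the patch avoids.
    if double_buffering:
        return reload_double_buffer[double_buffer_idx][buffer_idx]
    return None

buffers = [["buf00", "buf01"], ["buf10", "buf11"]]
print(pick_reload_buffer(True, buffers, 1, 0))   # 'buf10'
print(pick_reload_buffer(False, None, 1, 0))     # None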
transformer_engine/pytorch/cross_entropy.py  (view file @ 87e3e56e)

@@ -29,6 +29,7 @@ class CrossEntropyFunction(torch.autograd.Function):
         reduce_loss=False,
         dist_process_group=None,
         ignore_idx=-100,
+        is_cg_capturable=False,
     ):
         """
         The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each

@@ -47,10 +48,16 @@ class CrossEntropyFunction(torch.autograd.Function):
             tensor: The computed loss.
         """
         loss, _input = triton_cross_entropy.cross_entropy_forward(
-            _input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
+            _input,
+            target,
+            label_smoothing,
+            reduce_loss,
+            dist_process_group,
+            ignore_idx,
         )

         ctx.save_for_backward(_input.detach())
+        ctx.is_cg_capturable = is_cg_capturable
         return loss

     @staticmethod

@@ -66,13 +73,17 @@ class CrossEntropyFunction(torch.autograd.Function):
             tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         (_input,) = ctx.saved_tensors
-        _input = triton_cross_entropy.cross_entropy_backward(_input, grad_output)
+        _input = triton_cross_entropy.cross_entropy_backward(
+            _input, grad_output, ctx.is_cg_capturable
+        )
         return (
             _input,
             None,
             None,
             None,
             None,
             None,
+            None,
         )
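A hedged sketch of the pattern this change relies on: a non-tensor flag is stashed on `ctx` in forward and read back in backward, and the added forward argument gets a matching None gradient slot. The loss math below is a stand-in for illustration, not the Triton kernel.

import torch

class ToyLossFunction(torch.autograd.Function):
    """Sketch: a plain Python flag rides on ctx from forward to backward."""

    @staticmethod
    def forward(ctx, logits, target, is_cg_capturable=False):
        ctx.save_for_backward(logits, target)
        ctx.is_cg_capturable = is_cg_capturable  # plain attribute, not a saved tensor
        return torch.nn.functional.cross_entropy(logits, target)

    @staticmethod
    def backward(ctx, grad_output):
        logits, target = ctx.saved_tensors
        # A CUDA-graph-capturable path would avoid host syncs and dynamic shapes here.
        _ = ctx.is_cg_capturable
        probs = torch.softmax(logits, dim=-1)
        probs[torch.arange(logits.shape[0]), target] -= 1.0
        grad = grad_output * probs / logits.shape[0]
        # One gradient slot per forward input: logits, target, is_cg_capturable.
        return grad, None, None

x = torch.randn(4, 10, requires_grad=True)
y = torch.randint(0, 10, (4,))
ToyLossFunction.apply(x, y, True).backward()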
transformer_engine/pytorch/csrc/common.cpp  (view file @ 87e3e56e)

@@ -12,7 +12,7 @@
 namespace transformer_engine::pytorch {

-std::vector<size_t> getTensorShape(at::Tensor t) {
+std::vector<size_t> getTensorShape(const at::Tensor &t) {
   std::vector<size_t> shape;
   for (auto s : t.sizes()) {
     shape.push_back(s);

@@ -286,7 +286,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
   return std::vector<size_t>(shape.data, shape.data + shape.ndim);
 }

-int roundup(const int value, const int multiple) {
+size_t roundup(const size_t value, const size_t multiple) {
   assert(multiple > 0);
   return ((value + multiple - 1) / multiple) * multiple;
 }
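For reference, a Python restatement of the `roundup` helper and of why widening its type matters (the large-value example is illustrative):

def roundup(value: int, multiple: int) -> int:
    # Round value up to the next multiple of `multiple`.
    assert multiple > 0
    return ((value + multiple - 1) // multiple) * multiple

print(roundup(13, 8))   # 16
print(roundup(16, 8))   # 16
# Python ints never overflow, but in C++ a 32-bit int would wrap for buffer
# sizes near 2**31, which is what the int -> size_t change guards against.
print(roundup(2**31 + 5, 256))  # 2147483904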
transformer_engine/pytorch/csrc/common.h  (view file @ 87e3e56e)

@@ -116,9 +116,21 @@ class Quantizer {
   virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

-  virtual std::pair<TensorWrapper, py::object> create_tensor(
-      const std::vector<size_t>& shape, DType dtype,
-      std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
+  /*! @brief Construct a tensor with uninitialized data */
+  virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
+                                                             DType dtype) const = 0;
+
+  /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
+   *
+   * The PyTorch tensor's attributes are modified to match the
+   * quantizer's configuration.
+   */
+  virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
+      py::object tensor) const = 0;
+
+  /*! @brief Convert to a quantized data format */
+  virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
+                        const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;

   virtual ~Quantizer() = default;

@@ -139,9 +151,17 @@ class NoneQuantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override {}

-  std::pair<TensorWrapper, py::object> create_tensor(
-      const std::vector<size_t>& shape, DType dtype,
-      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
+                                                     DType dtype) const override;
+
+  /*! @brief Construct a tensor with pre-initialized data */
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Tensor data) const;
+
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
+
+  void quantize(const TensorWrapper& input, TensorWrapper& out,
+                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
 };

 class Float8Quantizer : public Quantizer {

@@ -157,9 +177,19 @@ class Float8Quantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override;

-  std::pair<TensorWrapper, py::object> create_tensor(
-      const std::vector<size_t>& shape, DType dtype,
-      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
+                                                     DType dtype) const override;
+
+  /*! @brief Construct a tensor with pre-initialized data */
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     std::optional<at::Tensor> data,
+                                                     std::optional<at::Tensor> transpose,
+                                                     std::optional<at::Tensor> scale_inv) const;
+
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
+
+  void quantize(const TensorWrapper& input, TensorWrapper& out,
+                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
 };

 class Float8CurrentScalingQuantizer : public Quantizer {

@@ -179,9 +209,29 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override;

-  std::pair<TensorWrapper, py::object> create_tensor(
-      const std::vector<size_t>& shape, DType dtype,
-      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
+                                                     DType dtype) const override;
+
+  /*! @brief Construct a high precision tensor giving it this quantizer's amax
+
+     Note: this member function also zeros out the amax, as it is meant to be used in conjunction
+     with a kernel computing the amax, which might expect the amax to be initialized to zero
+   */
+  std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape,
+                                                                  DType dtype);
+
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
+
+  void quantize(const TensorWrapper& input, TensorWrapper& out,
+                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
+
+  /*! @brief Convert to a quantized data format avoiding amax computation */
+  void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
+                          const std::optional<TensorWrapper>& noop_flag = std::nullopt);
+
+ private:
+  void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
+                     const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
 };

 class Float8BlockQuantizer : public Quantizer {

@@ -213,9 +263,13 @@ class Float8BlockQuantizer : public Quantizer {
   // Create a python Float8BlockQuantized tensor and C++ wrapper
   // for the tensor. Should set quantized data, scales for rowwise
   // and optionally columnwise usage.
-  std::pair<TensorWrapper, py::object> create_tensor(
-      const std::vector<size_t>& shape, DType dtype,
-      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
+                                                     DType dtype) const override;
+
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
+
+  void quantize(const TensorWrapper& input, TensorWrapper& out,
+                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

   std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
 };

@@ -230,16 +284,20 @@ class MXFP8Quantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override;

-  std::pair<TensorWrapper, py::object> create_tensor(
-      const std::vector<size_t>& shape, DType dtype,
-      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
+                                                     DType dtype) const override;
+
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
+
+  void quantize(const TensorWrapper& input, TensorWrapper& out,
+                const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

   std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
 };

 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

-std::vector<size_t> getTensorShape(at::Tensor t);
+std::vector<size_t> getTensorShape(const at::Tensor& t);

 transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                       const std::string& fp8_recipe);

@@ -382,7 +440,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
 std::vector<size_t> convertShape(const NVTEShape& shape);

-int roundup(const int value, const int multiple);
+size_t roundup(const size_t value, const size_t multiple);

 NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);

 }  // namespace transformer_engine::pytorch
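The refactored C++ interface separates "make me an output tensor" from "quantize into it", so callers can stage a high-precision temporary and quantize later. A hedged Python-flavoured mirror of that split (class and method names are illustrative, not TE's API):

from abc import ABC, abstractmethod
from typing import Optional, Tuple
import torch

class QuantizerSketch(ABC):
    """Rough Python mirror of the create_tensor / quantize split."""

    @abstractmethod
    def create_tensor(self, shape: Tuple[int, ...], dtype: torch.dtype) -> torch.Tensor:
        """Construct an (uninitialized) output tensor for this quantizer."""

    @abstractmethod
    def quantize(self, src: torch.Tensor, dst: torch.Tensor,
                 noop_flag: Optional[torch.Tensor] = None) -> None:
        """Convert src into dst's quantized format (optionally a no-op)."""

class NoneQuantizerSketch(QuantizerSketch):
    def create_tensor(self, shape, dtype):
        return torch.empty(shape, dtype=dtype)

    def quantize(self, src, dst, noop_flag=None):
        dst.copy_(src)  # high-precision "quantization" is just a copy

q = NoneQuantizerSketch()
out = q.create_tensor((2, 4), torch.float32)
q.quantize(torch.randn(2, 4), out)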
transformer_engine/pytorch/csrc/extensions.h  (view file @ 87e3e56e)

@@ -11,6 +11,10 @@
 #include "common.h"

+class CommOverlapHelper;
+class CommOverlap;
+class CommOverlapP2P;
+
 namespace transformer_engine::pytorch {

 /***************************************************************************************************

@@ -118,7 +122,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
                              bool use_split_accumulator, CommOverlapCore* comm_overlap = nullptr,
                              std::optional<CommOverlapType> comm_type = std::nullopt,
-                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
+                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false,
+                             float alpha = 1.0f, std::optional<float> beta = std::nullopt);

 void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                     std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,

@@ -179,6 +184,8 @@ std::vector<at::Tensor> te_batchgemm_ts(
 at::Tensor fp8_transpose(at::Tensor input, DType otype,
                          std::optional<at::Tensor> output = std::nullopt);

+at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = std::nullopt);
+
 /***************************************************************************************************
  * Activations
  **************************************************************************************************/

@@ -455,6 +462,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
 void nvshmem_finalize();

+/***************************************************************************************************
+ * Comm+GEMM Overlap Wrappers
+ **************************************************************************************************/
+
+void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator,
+                                        at::Stream send_stream, at::Stream recv_stream);
+
 }  // namespace transformer_engine::pytorch

 /***************************************************************************************************

@@ -504,7 +518,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
   at::Tensor get_buffer(bool local_chunk = false,
                         std::optional<std::vector<int64_t>> shape = std::nullopt);

-  at::Stream get_communication_stream();
+  std::pair<at::Stream, at::Stream> get_communication_stream();

 };  // CommOverlap

@@ -525,7 +539,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
   at::Tensor get_buffer(bool local_chunk = false,
                         std::optional<std::vector<int64_t>> shape = std::nullopt);

-  at::Stream get_communication_stream();
+  std::pair<at::Stream, at::Stream> get_communication_stream();

 };  // CommOverlapP2P
transformer_engine/pytorch/csrc/extensions/activation.cpp  (view file @ 87e3e56e)

@@ -13,87 +13,92 @@ namespace transformer_engine::pytorch {

 template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
 py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
   init_extension();
-  auto my_quantizer = convert_quantizer(quantizer);
-  auto input_tensor = input.contiguous();
-  const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
-  const auto& te_input_shape = te_input.shape();
-  std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
-  input_shape[input_shape.size() - 1] /= shape_divisor;
-  auto fake_tensor_type = input.scalar_type();
-
-  auto [te_output, out] =
-      my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
-
-  // for current scaling, we need to compute amax first and then quantize
-  // because cache cannot fit in the entire tensor to compute amax and quantize
-  // the quantizer should not need amax reduction, no process group needed here
-  if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
-    // activation function might change the input data range, we need to first call the activation
-    // function and then find the amax and scale of that and then do the quantization
-    // get a NoneQuantizer to calculate amax of activation output
-    auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
-    auto [te_output_act, out_act] =
-        my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
-
-    NVTE_SCOPED_GIL_RELEASE({
-      act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
-      // use te_output_act as input to the compute amax and find the amax of activated tensor
-      nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
-    });
-
-    // my_quantizer here has to be a Float8CurrentScalingQuantizer
-    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
-    if (my_quantizer_cs->with_amax_reduction) {
-      NVTE_ERROR(
-          "per-tensor current scaling amax reduction is not supported in activation functions.");
-    }
-    QuantizationConfigWrapper quant_config;
-    quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
-    quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_scale_from_amax(te_output.data(), quant_config,
-                                   at::cuda::getCurrentCUDAStream());
-      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
-      te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
-      nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
-                       at::cuda::getCurrentCUDAStream());
-    });
-  } else {
-    NVTE_SCOPED_GIL_RELEASE(
-        { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
-  }
-
-  return out;
+
+  // Input tensor
+  auto input_tensor = input.contiguous();
+  const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
+
+  // Construct output tensor
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  const auto input_shape = input_cpp.shape();
+  std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
+  output_shape.back() /= shape_divisor;
+  auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
+  auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
+
+  // Compute activation
+  if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
+      detail::IsMXFP8Quantizers(quantizer.ptr())) {
+    // Compute activation directly
+    NVTE_SCOPED_GIL_RELEASE(
+        { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    // Compute activation in high-precision fused together with amax, then quantize.
+    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
+    NVTE_SCOPED_GIL_RELEASE(
+        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+    quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    // sanity check, since activation fusion is not supported for blockwise quantization yet
+    // need to raise an error here instead of silently going into act_func with wrong numerics
+    NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
+  } else {
+    // Compute activation in high-precision, then quantize
+    auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
+    NVTE_SCOPED_GIL_RELEASE(
+        { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
+    quantizer_cpp->quantize(temp_cpp, out_cpp);
+  }
+
+  return out_py;
 }

-template <void (*act_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
-py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
-                              py::handle quantizer) {
+template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
+py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
+                              py::handle quantizer) {
   init_extension();
-  auto my_quantizer = convert_quantizer(quantizer);
-  auto input_tensor = input.contiguous();
-  auto grad_tensor = grad.contiguous();
-  const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
-  const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor);
-  const auto& te_input_shape = te_input.shape();
-  std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
-  auto fake_tensor_type = input.scalar_type();
-
-  auto [te_output, out] =
-      my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
-
-  NVTE_SCOPED_GIL_RELEASE({
-    act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
-  });
-
-  return out;
+
+  // Grad output and input tensors
+  auto grad_output_tensor = grad_output.contiguous();
+  auto input_tensor = input.contiguous();
+  const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
+  const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
+
+  // Construct grad input tensor
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  const auto input_shape_te = input_cpp.shape();
+  const std::vector<size_t> input_shape(input_shape_te.data,
+                                        input_shape_te.data + input_shape_te.ndim);
+  auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
+  auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
+
+  // Compute activation backward
+  if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
+      detail::IsMXFP8Quantizers(quantizer.ptr())) {
+    // Compute activation backward directly
+    NVTE_SCOPED_GIL_RELEASE({
+      dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
+                at::cuda::getCurrentCUDAStream());
+    });
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    // Compute activation backward in high-precision fused together with amax, then quantize.
+    auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
+    auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
+    NVTE_SCOPED_GIL_RELEASE({
+      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
+                at::cuda::getCurrentCUDAStream());
+    });
+    quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
+  } else {
+    // Compute activation backward in high-precision, then quantize
+    auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
+    NVTE_SCOPED_GIL_RELEASE({
+      dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
+                at::cuda::getCurrentCUDAStream());
+    });
+    quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
+  }
+
+  return grad_input_py;
 }

 py::object gelu(const at::Tensor& input, py::handle quantizer) {
View file @
87e3e56e
...
@@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
...
@@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
auto
max_tokens
=
shape
[
0
];
auto
max_tokens
=
shape
[
0
];
auto
fcd_size
=
1
;
auto
fcd_size
=
1
;
for
(
in
t
i
=
1
;
i
<=
shape
.
size
();
i
++
)
{
for
(
size_
t
i
=
1
;
i
        <= shape.size(); i++) {
      fcd_size *= shape[i];
    }
  }
...
@@ -110,8 +110,20 @@ std::vector<py::object> fused_attn_fwd(
   auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
   o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
   py::object o_python, s_python;
-  std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
-  std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
+  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+    // Initialize FP8 tensor with scale-inverse
+    auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
+    auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
+    NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
+    NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
+    std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
+                                                              std::nullopt, std::nullopt);
+    std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
+                                                              std::nullopt, std::nullopt);
+  } else {
+    std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
+    std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
+  }
   auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
   // construct NVTE tensors
...
@@ -295,8 +307,20 @@ std::vector<py::object> fused_attn_bwd(
   py::object s_python, dp_python;
   std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
   std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
-  std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
-  std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
+  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+    auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
+    auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get());
+    NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
+    NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
+    std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
+                                                              std::nullopt, std::nullopt);
+    std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
+                                                                 std::nullopt, std::nullopt);
+  } else {
+    std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
+    std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
+  }
   std::vector<size_t> q_shape = convertShape(te_Q.shape());
   std::vector<size_t> k_shape = convertShape(te_K.shape());
...
@@ -385,9 +409,22 @@ std::vector<py::object> fused_attn_bwd(
     default:
       NVTE_ERROR("QKV layout not supported!");
   }
-  std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
-  std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
-  std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
+  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+    auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get());
+    NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8");
+    std::tie(te_dQ, py_dQ) =
+        fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
+    std::tie(te_dK, py_dK) =
+        fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
+    std::tie(te_dV, py_dV) =
+        fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
+  } else {
+    auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
+    NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
+    std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
+    std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
+    std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
+  }
   // construct NVTE tensors
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
...
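Note: the hunks above downcast a polymorphic quantizer with dynamic_cast and assert the cast succeeded before calling the wider create_tensor overload. A minimal standalone sketch of that dispatch pattern follows; the class names and check are illustrative stand-ins, not the TransformerEngine API.

#include <cassert>
#include <iostream>

// Illustrative stand-ins for the quantizer hierarchy; not the real TE classes.
struct Quantizer {
  virtual ~Quantizer() = default;
  virtual const char *name() const { return "NoneQuantizer"; }
};
struct Float8Quantizer : Quantizer {
  const char *name() const override { return "Float8Quantizer"; }
  void create_fp8_tensor() const { std::cout << "allocating FP8 tensor with scale-inverse\n"; }
};

void make_output(const Quantizer &q, bool is_fp8) {
  if (is_fp8) {
    // Downcast and fail loudly if the caller passed the wrong quantizer type.
    auto *fp8 = dynamic_cast<const Float8Quantizer *>(&q);
    assert(fp8 != nullptr && "Expected Float8Quantizer when dtype is FP8");
    fp8->create_fp8_tensor();
  } else {
    std::cout << "allocating high-precision tensor via " << q.name() << "\n";
  }
}

int main() {
  Float8Quantizer fp8;
  Quantizer none;
  make_output(fp8, /*is_fp8=*/true);
  make_output(none, /*is_fp8=*/false);
}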
transformer_engine/pytorch/csrc/extensions/bias.cpp
View file @ 87e3e56e
...
@@ -4,80 +4,223 @@
  * See LICENSE for license information.
  ************************************************************************/

+#include <ATen/ATen.h>
+#include <pybind11/pybind11.h>
+
+#include <utility>
+#include <vector>
+
 #include "common.h"
-#include "extensions.h"
 #include "pybind.h"
 #include "transformer_engine/cast.h"
 #include "transformer_engine/transformer_engine.h"

-namespace transformer_engine::pytorch {
+namespace transformer_engine {
+namespace pytorch {

-std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer) {
-  auto quantizer = convert_quantizer(py_quantizer);
-  auto input_tensor = makeTransformerEngineTensor(input);
-  auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype());
-  std::vector<size_t> output_shape;
-  for (auto s : input.sizes()) {
-    output_shape.emplace_back(static_cast<size_t>(s));
-  }
-  auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype());
-  // Return immediately if tensors are empty
-  if (product(output_shape) == 0) {
-    return {py::cast(dbias.zero_()), out};
-  }
-  auto dbias_tensor = makeTransformerEngineTensor(dbias);
-  // Query workspace size
-  transformer_engine::TensorWrapper workspace;
-  NVTE_SCOPED_GIL_RELEASE({
-    nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
-                        workspace.data(), at::cuda::getCurrentCUDAStream());
-  });
-  // Allocate workspace
-  void *workspace_data_ptr = nullptr;
-  if (workspace.shape().ndim > 0) {
-    auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
-    workspace_data_ptr = workspace_data.data_ptr();
-  }
-  workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype());
-  // Launch kernel
-  if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) {
-    // my_quantizer here has to be a Float8CurrentScalingQuantizer
-    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer.get());
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
-    });
-    // check if we need to do amax reudction (depending on model parallel configs)
-    if (my_quantizer_cs->with_amax_reduction) {
-      c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
-      // construct torch tesnor from NVTEBasicTensor without reallocating memory
-      at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
-      std::vector<at::Tensor> tensors = {amax_tensor_torch};
-      // allreduce amax tensor
-      c10d::AllreduceOptions allreduce_opts;
-      allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
-      process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
-    }
-    QuantizationConfigWrapper quant_config;
-    quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
-    quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_scale_from_amax(out_tensor.data(), quant_config,
-                                   at::cuda::getCurrentCUDAStream());
-    });
-    // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
-    out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape);
-  }
-  NVTE_SCOPED_GIL_RELEASE({
-    nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
-                        workspace.data(), at::cuda::getCurrentCUDAStream());
-  });
-  return {py::cast(dbias), out};
-}
+std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) {
+  using namespace transformer_engine::pytorch::detail;
+  init_extension();
+
+  // Grad output tensor
+  auto grad_output_torch = grad_output.contiguous();
+  const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
+  const auto shape = getTensorShape(grad_output_torch);
+  auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
+
+  // Construct grad bias tensor
+  const int64_t bias_size = static_cast<int64_t>(shape.back());
+  auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype);
+  auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch);
+
+  // Unquantized impl only requires computing grad bias
+  if (quantizer.is_none()) {
+    if (product(shape) == 0) {
+      grad_bias_torch.zero_();
+    } else {
+      at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
+    }
+    return {py::cast(std::move(grad_bias_torch)), py::cast(std::move(grad_output_torch))};
+  }
+
+  // Construct grad input tensor
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(shape, grad_output_dtype);
+
+  // Trivial impl if tensors are empty
+  if (product(shape) == 0) {
+    grad_bias_torch.zero_();
+    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
+  }
+
+  // Unfused impl if quantizer is not supported
+  const bool with_fused_dbias_quantize_kernel =
+      detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr());
+  if (!with_fused_dbias_quantize_kernel) {
+    at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
+    quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
+    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
+  }
+
+  // Query workspace size and allocate workspace
+  TensorWrapper workspace_nvte;
+  at::Tensor workspace_torch;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(),
+                        workspace_nvte.data(), stream);
+  });
+  if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) {
+    workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype());
+    workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(),
+                                                 workspace_nvte.shape(), workspace_nvte.dtype());
+  }
+
+  // Launch fused kernel
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(),
+                        workspace_nvte.data(), stream);
+  });
+  return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
+}
+
+namespace {
+
+std::vector<py::object> dact_dbias(void (*dact_dbias_func)(const NVTETensor, const NVTETensor,
+                                                           NVTETensor, NVTETensor, NVTETensor,
+                                                           cudaStream_t),
+                                   void (*dact_func)(const NVTETensor, const NVTETensor,
+                                                     NVTETensor, cudaStream_t),
+                                   at::Tensor grad_output_torch, at::Tensor act_input_torch,
+                                   py::handle quantizer_py) {
+  using namespace transformer_engine::pytorch::detail;
+  init_extension();
+
+  // Grad output and activation input tensors
+  grad_output_torch = grad_output_torch.contiguous();
+  const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
+  const auto output_shape = getTensorShape(grad_output_torch);
+  auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
+  act_input_torch = act_input_torch.contiguous();
+  const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
+  const auto input_shape = getTensorShape(act_input_torch);
+
+  // Construct tensors
+  auto quantizer_cpp = convert_quantizer(quantizer_py);
+  auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, grad_output_dtype);
+  const int64_t bias_size = static_cast<int64_t>(input_shape.back());
+  auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype);
+  auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch);
+
+  // Return immediately if tensors are empty
+  if (product(output_shape) == 0) {
+    grad_bias_torch.zero_();
+    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
+  }
+
+  // Choose implementation
+  enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX };
+  Impl impl = Impl::UNFUSED;
+  if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
+      detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
+    impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
+  } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
+    impl = Impl::FUSED_DACT_AMAX;
+  }
+
+  // Perform compute
+  auto stream = at::cuda::getCurrentCUDAStream();
+  switch (impl) {
+    case Impl::UNFUSED:
+      // Unfused dact, dbias, quantize
+      {
+        auto [temp_nvte, temp_py] =
+            NoneQuantizer(py::none()).create_tensor(input_shape, grad_output_dtype);
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
+        });
+        const auto temp_torch = temp_py.cast<at::Tensor>();
+        at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
+        quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
+        break;
+      }
+    case Impl::FUSED_DACT_DBIAS_QUANTIZE:
+      // Fused dact-dbias-quantize kernel
+      {
+        // Query workspace size
+        TensorWrapper workspace_nvte;
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(),
+                          grad_bias_nvte.data(), workspace_nvte.data(), stream);
+        });
+        // Allocate workspace
+        at::Tensor workspace_torch;
+        if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) {
+          workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype());
+          workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(),
+                                                       workspace_nvte.shape(), workspace_nvte.dtype());
+        }
+        // Launch kernel
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(),
+                          grad_bias_nvte.data(), workspace_nvte.data(), stream);
+        });
+        break;
+      }
+    case Impl::FUSED_DACT_AMAX:
+      // Fused dact-amax kernel, unfused dbias and quantize
+      {
+        auto *quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
+        NVTE_CHECK(quantizer_cpp_cs != nullptr, "Invalid quantizer for fused dact-amax kernel impl");
+        auto [temp_nvte, temp_py] =
+            quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype);
+        NVTE_SCOPED_GIL_RELEASE({
+          dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
+        });
+        const auto temp_torch = temp_py.cast<at::Tensor>();
+        at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
+        quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte);
+        break;
+      }
+    default:
+      NVTE_ERROR("Invalid implementation");
+  }
+
+  return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
+}
+
+}  // namespace
+
+std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
+                                    py::handle quantizer) {
+  return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer);
+}
+
+std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
+                                    py::handle quantizer) {
+  return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer);
+}
+
+std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
+                                    py::handle quantizer) {
+  return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer);
+}
+
+std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
+                                     py::handle quantizer) {
+  return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer);
+}
+
+std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
+                                     py::handle quantizer) {
+  return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer);
+}
+
-}  // namespace transformer_engine::pytorch
+}  // namespace pytorch
+}  // namespace transformer_engine
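Note: the unfused paths above obtain the bias gradient by summing the gradient over every dimension except the last. A minimal standalone sketch of that reduction follows, using an illustrative flat-array layout rather than the at::Tensor API.

#include <cstddef>
#include <cstdio>
#include <vector>

// dbias[j] = sum over all rows of grad_output[., j], with grad_output viewed as a
// [rows, bias_size] matrix -- the same reduction as
// at::sum_out(grad_bias, grad_output.reshape({-1, bias_size}), {0}).
std::vector<float> bias_grad(const std::vector<float> &grad_output, std::size_t bias_size) {
  std::vector<float> dbias(bias_size, 0.0f);
  const std::size_t rows = grad_output.size() / bias_size;
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t j = 0; j < bias_size; ++j) {
      dbias[j] += grad_output[r * bias_size + j];
    }
  }
  return dbias;
}

int main() {
  // Two rows of width 3: dbias should be {5, 7, 9}.
  std::vector<float> g = {1, 2, 3, 4, 5, 6};
  for (float v : bias_grad(g, 3)) std::printf("%.0f ", v);
  std::printf("\n");
}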
transformer_engine/pytorch/csrc/extensions/cast.cpp
View file @ 87e3e56e
...
@@ -28,60 +28,6 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
   return std::vector<size_t>(shape.data, shape.data + shape.ndim);
 }

-void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py,
-                   std::unique_ptr<Quantizer> &quantizer_cpp, TensorWrapper &output,
-                   TensorWrapper &noop_flag) {
-  // Check tensor dims
-  NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
-             "Input tensor (shape=", get_tensor_shape(input),
-             ") and output tensor (shape=", get_tensor_shape(output), ") do not match");
-  if (input.numel() == 0) {
-    return;
-  }
-
-  // Recipe-specific configuration
-  QuantizationConfigWrapper quant_config;
-  quant_config.set_noop_tensor(noop_flag.data());
-  if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
-    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream());
-    });
-    // check if we need to do amax reudction (depending on model parallel configs)
-    if (my_quantizer_cs->with_amax_reduction) {
-      c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
-      // construct torch tesnor from NVTEBasicTensor without reallocating memory
-      at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
-      std::vector<at::Tensor> tensors = {amax_tensor_torch};
-      // allreduce amax tensor
-      c10d::AllreduceOptions allreduce_opts;
-      allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
-      process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
-    }
-    // this config is used for cs scaling factor computation
-    // because compute scale is cannot be fused with quantize kernel
-    // so in nvte_quantize_v2 with current scaling, the quant config is not used again
-    quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
-    quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream());
-    });
-    // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
-    output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
-  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
-    auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
-    quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
-    quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
-    if (my_quantizer_bw->all_gather_usage) {
-      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
-    }
-  }
-
-  // Perform quantization
-  NVTE_SCOPED_GIL_RELEASE({
-    nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream());
-  });
-}
-
 }  // namespace

 py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
...
@@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
     const auto fake_dtype = input_cpp.dtype();
     std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
   } else {
-    output_py = output;
-    output_cpp = makeTransformerEngineTensor(output_py, quantizer);
+    std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output);
   }

   // Initialize no-op flag
-  TensorWrapper noop_flag_cpp;
+  std::optional<TensorWrapper> noop_flag_cpp;
   if (noop_flag.has_value()) {
     noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
   }

   // Perform quantization
-  quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
+  quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
   return output_py;
 }
...
@@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
     });
   } else {
     // Quantize kernels individually
-    TensorWrapper dummy_noop_flag;
     for (size_t i = 0; i < num_tensors; ++i) {
-      quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
-                    dummy_noop_flag);
+      quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]);
     }
   }
 }
...
@@ -455,11 +398,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   }

   // Allocate full buffer
-  // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
   auto buffer = std::make_shared<at::Tensor>(
-      at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  // auto buffer = std::make_shared<at::Tensor>(
-  //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+      at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));

   // Construct tensor views
   for (size_t i = 0; i < num_tensors; ++i) {
...
@@ -498,11 +438,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   }

   // Allocate full buffer
-  // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
   auto buffer = std::make_shared<at::Tensor>(
-      at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  // auto buffer = std::make_shared<at::Tensor>(
-  //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+      at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));

   // Construct tensor views
   for (size_t i = 0; i < num_tensors; ++i) {
...
@@ -650,66 +587,5 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
   return output_py_list;
 }

-template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
-                       cudaStream_t)>
-std::vector<py::object> dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input,
-                                   py::handle quantizer) {
-  init_extension();
-  auto my_quantizer = convert_quantizer(quantizer);
-  auto grad_tensor = makeTransformerEngineTensor(grad_output);
-  auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
-  auto act_input_tensor = makeTransformerEngineTensor(act_input);
-  const auto &shape = convertShape(grad_tensor.shape());
-  auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
-  auto dbias_tensor = makeTransformerEngineTensor(grad_bias);
-
-  // Query workspace size and allocate workspace
-  transformer_engine::TensorWrapper workspace;
-  NVTE_SCOPED_GIL_RELEASE({
-    func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
-         workspace.data(), at::cuda::getCurrentCUDAStream());
-  });
-  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
-  workspace =
-      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
-
-  // Launch kernel
-  NVTE_SCOPED_GIL_RELEASE({
-    func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
-         workspace.data(), at::cuda::getCurrentCUDAStream());
-  });
-
-  return {py::cast(grad_bias), dact};
-}
-
-std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
-                                    py::handle quantizer) {
-  return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
-}
-
-std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
-                                    py::handle quantizer) {
-  return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
-}
-
-std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
-                                    py::handle quantizer) {
-  return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
-}
-
-std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
-                                     py::handle quantizer) {
-  return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
-}
-
-std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
-                                     py::handle quantizer) {
-  return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
-}
-
 }  // namespace pytorch
 }  // namespace transformer_engine
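Note: the quantize() hunk above replaces a free function that branched on the Python quantizer type with a call to a virtual quantize method on the C++ quantizer object. A minimal standalone sketch of that refactoring direction follows; the class names are illustrative, not the TE API.

#include <iostream>
#include <memory>
#include <optional>
#include <string>

// Illustrative stand-ins: each quantizer owns its recipe-specific configuration,
// so callers no longer need a type-switch before launching the kernel.
struct Tensor { std::string name; };

struct Quantizer {
  virtual ~Quantizer() = default;
  virtual void quantize(const Tensor &in, Tensor &out,
                        const std::optional<Tensor> &noop_flag = std::nullopt) const = 0;
};

struct DelayedScalingQuantizer : Quantizer {
  void quantize(const Tensor &in, Tensor &out, const std::optional<Tensor> &) const override {
    std::cout << "delayed-scaling quantize: " << in.name << " -> " << out.name << "\n";
  }
};

struct CurrentScalingQuantizer : Quantizer {
  void quantize(const Tensor &in, Tensor &out, const std::optional<Tensor> &) const override {
    std::cout << "compute amax, then quantize: " << in.name << " -> " << out.name << "\n";
  }
};

int main() {
  std::unique_ptr<Quantizer> q = std::make_unique<CurrentScalingQuantizer>();
  Tensor in{"x"}, out{"x_fp8"};
  q->quantize(in, out);  // recipe-specific steps happen behind the virtual call
}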
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
View file @ 87e3e56e
...
@@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }

-at::Stream CommOverlap::get_communication_stream() {
-  return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
+std::pair<at::Stream, at::Stream> CommOverlap::get_communication_stream() {
+  // Return the same stream for both send and recv
+  return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()),
+          at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())};
 }

 /***************************************************************************************************
...
@@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }

-at::Stream CommOverlapP2P::get_communication_stream() {
-  return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
+std::pair<at::Stream, at::Stream> CommOverlapP2P::get_communication_stream() {
+  return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()),
+          at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())};
+}
+
+void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm(
+    CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) {
+  auto main_stream = at::cuda::getCurrentCUDAStream();
+  allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream),
+                                                  at::cuda::CUDAStream(recv_stream), main_stream);
 }
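Note: both overlap classes now hand back a (send, recv) stream pair instead of a single stream, and callers unpack it with a structured binding. A small standalone sketch of the call-site pattern follows, using a placeholder stream type rather than at::Stream.

#include <iostream>
#include <string>
#include <utility>

// Placeholder for an opaque stream handle; the real code returns at::Stream.
struct StreamHandle { std::string label; };

// The P2P variant exposes distinct send/recv streams; per the diff comment, the
// non-P2P variant returns the same underlying stream twice.
std::pair<StreamHandle, StreamHandle> get_communication_stream() {
  return {StreamHandle{"send"}, StreamHandle{"recv"}};
}

int main() {
  auto [send_stream, recv_stream] = get_communication_stream();
  std::cout << "scheduling all-gather sends on " << send_stream.label
            << ", receives on " << recv_stream.label << "\n";
}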
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @ 87e3e56e
...
@@ -94,7 +94,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
                              bool use_split_accumulator, CommOverlapCore *comm_overlap,
                              std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
-                             bool bulk_overlap) {
+                             bool bulk_overlap, float alpha, std::optional<float> beta) {
   // Input tensors
   NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
   NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
...
@@ -112,6 +112,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
   NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");

+  // Check scaling factors
+  if (accumulate) {
+    if (!beta) {
+      beta = 1.0f;
+    }
+  } else {
+    if (!beta) {
+      beta = 0.0f;
+    }
+    NVTE_CHECK(beta == 0.0, "Trying to use non-zero beta while not accumulating ",
+               "into D tensor. Beta has nothing to be applied to.");
+  }
+
   // Output tensor
   TensorWrapper D_tensor;
   if (D.is_none()) {
...
@@ -240,9 +253,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     } else {
       // Launch GEMM
       NVTE_SCOPED_GIL_RELEASE({
-        nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
-                         te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
-                         accumulate, use_split_accumulator, num_math_sms, main_stream);
+        nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+                                bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
+                                te_workspace.data(), alpha, *beta, use_split_accumulator,
+                                num_math_sms, main_stream);
       });
     }
   } else {
...
@@ -328,10 +342,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
   std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
       te_pre_gelu_out_vector, te_workspace_vector;
-  std::vector<TensorWrapper> wrappers;
+  std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
   std::vector<at::Tensor> D_vectors;
-  // Keep the swizzled scaling factor tensors alive during the GEMMs.
-  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
   auto none = py::none();
...
@@ -398,10 +410,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
       continue;
     }
-    // Optionally swizzle the scaling factors
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
     auto te_D = makeTransformerEngineTensor(out_tensor);
     auto te_bias = makeTransformerEngineTensor(bias[i]);
     auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
...
@@ -421,18 +429,25 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     te_bias_vector.emplace_back(te_bias.data());
     te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
-    wrappers.emplace_back(std::move(te_A));
-    wrappers.emplace_back(std::move(te_B));
+    te_A_wrappers.emplace_back(std::move(te_A));
+    te_B_wrappers.emplace_back(std::move(te_B));
     wrappers.emplace_back(std::move(te_D));
     wrappers.emplace_back(std::move(te_bias));
     wrappers.emplace_back(std::move(te_pre_gelu_out));
   }
+
+  // Optionally swizzle the scaling factors
+  // Keep the swizzled scaling factor tensors alive during the GEMMs.
+  auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
+  auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);
+
   for (size_t i = 0; i < workspace.size(); i++) {
     auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
                                            std::vector<size_t>{workspaceSize}, DType::kByte);
     te_workspace_vector.emplace_back(wsp.data());
     wrappers.emplace_back(std::move(wsp));
   }
   // For now, we only have multi-stream cublas backend.
   const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
   if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1'){
...
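Note: the new alpha/beta arguments added to gemm() above give `beta` an `accumulate`-dependent default and reject a non-zero `beta` when there is nothing to accumulate into. A minimal standalone sketch of just that defaulting rule follows; the helper name is illustrative, not part of the TE API.

#include <cstdio>
#include <optional>
#include <stdexcept>

// D = alpha * (A @ B) + beta * D: with accumulation an unset beta defaults to 1,
// otherwise it defaults to 0 and any non-zero value is rejected.
float resolve_beta(bool accumulate, std::optional<float> beta) {
  if (accumulate) {
    return beta.value_or(1.0f);
  }
  const float b = beta.value_or(0.0f);
  if (b != 0.0f) {
    throw std::invalid_argument("non-zero beta without accumulation has nothing to be applied to");
  }
  return b;
}

int main() {
  std::printf("accumulate, beta unset    -> %.1f\n", resolve_beta(true, std::nullopt));
  std::printf("no accumulate, beta unset -> %.1f\n", resolve_beta(false, std::nullopt));
}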
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
View file @ 87e3e56e
...
@@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                               num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
-                              weight_decay, device_id, at::cuda::getCurrentCUDAStream());
+                              weight_decay, at::cuda::getCurrentCUDAStream());
 }

 void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
...
@@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_adam_param_remainder_cuda(
       chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
-      beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
-      at::cuda::getCurrentCUDAStream());
+      beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream());
 }

 void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
...
@@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
                                   num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
                                   bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
-                                  device_id, at::cuda::getCurrentCUDAStream());
+                                  at::cuda::getCurrentCUDAStream());
 }

 void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
...
@@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
   auto lr_cu = makeTransformerEngineTensor(lr);
   auto step_cu = makeTransformerEngineTensor(step);
   auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_adam_capturable_cuda(
       chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
       lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
-      inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
+      inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
 }

 void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
...
@@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
   auto lr_cu = makeTransformerEngineTensor(lr);
   auto step_cu = makeTransformerEngineTensor(step);
   auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_adam_capturable_master_cuda(
       chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
       lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
-      inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
+      inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
 }

 }  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
View file @ 87e3e56e
...
@@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
       chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
-      force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
+      force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream());
 }

 }  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
View file @ 87e3e56e
...
@@ -43,12 +43,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
   auto ret_cu = makeTransformerEngineTensor(ret);
   auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                                 num_tensors, output_cu.data(), output_per_tensor_cu.data(),
                                 ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
-                                max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
+                                max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());

   return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
...
@@ -91,13 +90,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
   auto ret_cu = makeTransformerEngineTensor(ret);
   auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
   auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_unscale_l2norm_cuda(
       chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
       output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
-      inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
-      at::cuda::getCurrentCUDAStream());
+      inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());

   return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
View file @ 87e3e56e
...
@@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
   auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
   auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
       makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();

   nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
-                               num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
+                               num_tensors, scale, at::cuda::getCurrentCUDAStream());
 }

 }  // namespace transformer_engine::pytorch