Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
87e3e56e
Commit
87e3e56e
authored
Aug 27, 2025
by
yuguo
Browse files
Merge commit '
734bcedd
' of...
Merge commit '
734bcedd
' of
https://github.com/NVIDIA/TransformerEngine
parents
2f11bd2e
734bcedd
Changes
217
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
660 additions
and
390 deletions
+660
-390
transformer_engine/jax/sharding.py
transformer_engine/jax/sharding.py
+25
-34
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
.../attention/dot_product_attention/dot_product_attention.py
+1
-1
transformer_engine/pytorch/attention/dot_product_attention/utils.py
...r_engine/pytorch/attention/dot_product_attention/utils.py
+3
-3
transformer_engine/pytorch/attention/multi_head_attention.py
transformer_engine/pytorch/attention/multi_head_attention.py
+116
-23
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+16
-0
transformer_engine/pytorch/cpu_offload.py
transformer_engine/pytorch/cpu_offload.py
+15
-3
transformer_engine/pytorch/cross_entropy.py
transformer_engine/pytorch/cross_entropy.py
+13
-2
transformer_engine/pytorch/csrc/common.cpp
transformer_engine/pytorch/csrc/common.cpp
+2
-2
transformer_engine/pytorch/csrc/common.h
transformer_engine/pytorch/csrc/common.h
+78
-20
transformer_engine/pytorch/csrc/extensions.h
transformer_engine/pytorch/csrc/extensions.h
+17
-3
transformer_engine/pytorch/csrc/extensions/activation.cpp
transformer_engine/pytorch/csrc/extensions/activation.cpp
+74
-69
transformer_engine/pytorch/csrc/extensions/attention.cpp
transformer_engine/pytorch/csrc/extensions/attention.cpp
+45
-8
transformer_engine/pytorch/csrc/extensions/bias.cpp
transformer_engine/pytorch/csrc/extensions/bias.cpp
+198
-55
transformer_engine/pytorch/csrc/extensions/cast.cpp
transformer_engine/pytorch/csrc/extensions/cast.cpp
+6
-130
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
...rmer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
+14
-4
transformer_engine/pytorch/csrc/extensions/gemm.cpp
transformer_engine/pytorch/csrc/extensions/gemm.cpp
+28
-13
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
...rmer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
+5
-11
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
...ne/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
+1
-2
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
...er_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
+2
-5
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
...mer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
+1
-2
No files found.
transformer_engine/jax/sharding.py
View file @
87e3e56e
...
...
@@ -15,10 +15,10 @@ from dataclasses import dataclass
from
enum
import
Enum
from
typing
import
Callable
,
Optional
import
warnings
from
jax.interpreters
import
pxla
import
jax
import
jax.numpy
as
jnp
from
jax.sharding
import
PartitionSpec
from
jax.interpreters
import
pxla
from
jax.sharding
import
PartitionSpec
,
get_abstract_mesh
import
numpy
as
np
_PXLA_THREAD_RESOURCES
=
pxla
.
thread_resources
...
...
@@ -86,24 +86,29 @@ def get_sharding_map_logic_axis_to_mesh_axis():
return
te_logical_axis_to_mesh_axis
def
generate_pspec
(
logical_axis_names
):
def
_
generate_pspec
(
logical_axis_names
):
"""
Convert logical axes to PartitionSpec
Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec.
Note, this method does not support Flax logical axes.
Args:
logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec.
Returns:
A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine logical axis names
"""
rules
=
get_sharding_map_logic_axis_to_mesh_axis
()
# mesh_axis_names = [rules[name] for name in logical_axis_names]
mesh_axis_names
=
[]
for
name
in
logical_axis_names
:
axis_name
=
rules
[
name
]
if
name
in
rules
else
None
mesh_axis_names
.
append
(
axis_name
)
mesh_axis_names
=
[
rules
.
get
(
name
)
for
name
in
logical_axis_names
]
pspec
=
jax
.
sharding
.
PartitionSpec
(
*
mesh_axis_names
)
return
pspec
def
with_sharding_constraint
(
x
:
jnp
.
array
,
pspec
:
PartitionSpec
):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
A wrapper function to jax.lax.with_sharding_constraint
1. Does nothing if mesh is empty.
2. If all mesh axes are manual axes, replaces pspec with all Nones.
3. Otherwise, strips only the manual axes.
"""
if
pspec
is
None
:
return
x
...
...
@@ -111,7 +116,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
mesh
=
_PXLA_THREAD_RESOURCES
.
env
.
physical_mesh
if
mesh
.
empty
:
return
x
return
jax
.
lax
.
with_sharding_constraint
(
x
,
pspec
)
# We want to exclude the axes that already used by shard_map and shard_map
# only sets those in the abstract_mesh, not the physical one
manual_axis_names
=
get_abstract_mesh
().
manual_axes
cleaned_axis_names
=
tuple
(
name
if
name
not
in
manual_axis_names
else
None
for
name
in
pspec
)
cleaned_pspec
=
PartitionSpec
(
*
cleaned_axis_names
)
return
jax
.
lax
.
with_sharding_constraint
(
x
,
cleaned_pspec
)
def
with_sharding_constraint_by_logical_axes
(
...
...
@@ -159,7 +171,7 @@ def with_sharding_constraint_by_logical_axes(
# If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
assert
len
(
x
.
shape
)
==
len
(
logical_axis_names
)
pspec
=
generate_pspec
(
logical_axis_names
)
pspec
=
_
generate_pspec
(
logical_axis_names
)
return
with_sharding_constraint
(
x
,
pspec
)
...
...
@@ -383,24 +395,3 @@ class ShardingType(Enum):
TP_ROW
=
(
MajorShardingType
.
TP
,
"tp_row"
)
DP_TP_COL
=
(
MajorShardingType
.
DPTP
,
"dp_tp_col"
)
DP_TP_ROW
=
(
MajorShardingType
.
DPTP
,
"dp_tp_row"
)
def
get_non_contracting_logical_axes
(
ndim
,
logical_axes
:
tuple
[
Optional
[
str
]],
contracting_dims
)
->
tuple
[
Optional
[
str
]]:
"""Get logical axes for non-contracting dimensions.
Args:
ndim: Number of dimensions in the tensor.
logical_axes: Tuple of logical axes for each dimension.
contracting_dims: Set of dimensions that are being contracted.
Returns:
Tuple of logical axes for non-contracting dimensions.
"""
assert
logical_axes
is
not
None
,
"Logical axes must be a tuple and cannot be None."
assert
len
(
logical_axes
)
==
ndim
,
"Logical axes must match the number of dimensions."
non_contracting_dims
=
[
i
for
i
in
range
(
ndim
)
if
i
not
in
contracting_dims
]
non_contracting_logical_axes
=
tuple
(
logical_axes
[
i
]
for
i
in
non_contracting_dims
)
return
non_contracting_logical_axes
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
View file @
87e3e56e
...
...
@@ -630,7 +630,7 @@ class DotProductAttention(TransformerEngineBaseModule):
If true, there are padding tokens between individual sequences in a packed batch.
"""
with
self
.
prepare_forward
(
with
torch
.
cuda
.
device
(
query_layer
.
device
),
self
.
prepare_forward
(
query_layer
,
num_gemms
=
3
,
allow_non_contiguous
=
True
,
...
...
transformer_engine/pytorch/attention/dot_product_attention/utils.py
View file @
87e3e56e
...
...
@@ -438,8 +438,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if
inference_params
is
not
None
:
if
device_compute_capability
==
(
8
,
9
)
and
cudnn_version
<
(
9
,
12
,
0
):
logger
.
debug
(
"Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12"
)
if
device_compute_capability
==
(
8
,
9
)
and
cudnn_version
<
=
(
9
,
12
,
0
):
logger
.
debug
(
"Disabling FusedAttention for KV caching for sm89 and cuDNN <
=
9.12"
)
use_fused_attention
=
False
if
context_parallel
:
logger
.
debug
(
"Disabling all backends for KV caching with context parallelism"
)
...
...
@@ -625,7 +625,7 @@ def get_attention_backend(
" bias for THD format"
)
use_fused_attention
=
False
elif
fp8
and
head_dim_qk
!=
head_dim_v
:
elif
fp8
and
fp8_meta
[
"recipe"
].
fp8_dpa
and
head_dim_qk
!=
head_dim_v
:
logger
.
debug
(
"Disabling FusedAttention as it does not support context parallelism with FP8"
" MLA attention"
...
...
transformer_engine/pytorch/attention/multi_head_attention.py
View file @
87e3e56e
...
...
@@ -11,7 +11,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState
from
transformer_engine.pytorch.fp8
import
FP8GlobalStateManager
from
transformer_engine.pytorch.float8_tensor
import
Float8Tensor
from
transformer_engine.pytorch.module.base
import
TransformerEngineBaseModule
from
transformer_engine.pytorch.module
import
LayerNormLinear
,
Linear
from
transformer_engine.pytorch.module
import
LayerNormLinear
,
Linear
,
RMSNorm
,
LayerNorm
from
transformer_engine.pytorch.ops.basic.l2normalization
import
L2Normalization
from
transformer_engine.pytorch.utils
import
(
SplitAlongDim
,
...
...
@@ -175,14 +175,23 @@ class MultiheadAttention(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatentations/splits and also enables the argument
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = 'False'
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_type: Optional[str], default = None
type of normalization to apply to query and key tensors.
Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
When 'L2Normalization', L2 normalization is applied to query and key tensors.
When 'RMSNorm', RMS normalization is applied to query and key tensors.
When 'LayerNorm', layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach
for QK normalization to improve training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
epsilon value for normalization of query and key tensors.
Only used when `qk_norm_type` is not None.
qk_norm_before_rope: bool, default = `False`
if set to `True`, query and key normalization is applied before rotary position
embedding. When `False` (default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
seq_length: Optional[int], default = `None`
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
...
...
@@ -231,8 +240,9 @@ class MultiheadAttention(torch.nn.Module):
device
:
Union
[
torch
.
device
,
str
]
=
"cuda"
,
qkv_format
:
str
=
"sbhd"
,
name
:
str
=
None
,
use_
qk_norm
:
bool
=
Fals
e
,
qk_norm
_type
:
Optional
[
str
]
=
Non
e
,
qk_norm_eps
:
float
=
1e-6
,
qk_norm_before_rope
:
bool
=
False
,
seq_length
:
Optional
[
int
]
=
None
,
micro_batch_size
:
Optional
[
int
]
=
None
,
)
->
None
:
...
...
@@ -264,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
qkv_weight_interleaved
=
False
self
.
qkv_weight_interleaved
=
qkv_weight_interleaved
self
.
rotary_pos_interleaved
=
rotary_pos_interleaved
self
.
qk_norm_before_rope
=
qk_norm_before_rope
assert
attention_type
in
AttnTypes
,
f
"attention_type
{
attention_type
}
not supported"
if
layer_number
is
not
None
:
...
...
@@ -288,7 +299,6 @@ class MultiheadAttention(torch.nn.Module):
self
.
hidden_size_kv
=
self
.
hidden_size_per_attention_head
*
self
.
num_gqa_groups
self
.
name
=
name
self
.
use_qk_norm
=
use_qk_norm
common_gemm_kwargs
=
{
"fuse_wgrad_accumulation"
:
fuse_wgrad_accumulation
,
...
...
@@ -300,13 +310,9 @@ class MultiheadAttention(torch.nn.Module):
"device"
:
device
,
}
# Initialize L2 normalization modules for query and key if enabled
if
self
.
use_qk_norm
:
self
.
qk_norm
=
L2Normalization
(
eps
=
qk_norm_eps
,
seq_length
=
seq_length
,
micro_batch_size
=
micro_batch_size
,
)
self
.
q_norm
,
self
.
k_norm
=
self
.
_create_qk_norm_modules
(
qk_norm_type
,
qk_norm_eps
,
device
,
seq_length
,
micro_batch_size
)
qkv_parallel_mode
=
"column"
if
set_parallel_mode
else
None
...
...
@@ -427,6 +433,78 @@ class MultiheadAttention(torch.nn.Module):
**
common_gemm_kwargs
,
)
def
_create_qk_norm_modules
(
self
,
qk_norm_type
:
Optional
[
str
],
qk_norm_eps
:
float
,
device
:
Union
[
torch
.
device
,
str
],
seq_length
:
Optional
[
int
]
=
None
,
micro_batch_size
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
Optional
[
torch
.
nn
.
Module
],
Optional
[
torch
.
nn
.
Module
]]:
"""
Create query and key normalization modules based on the specified normalization type.
Parameters
----------
qk_norm_type : Optional[str]
Type of normalization to apply. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'
qk_norm_eps : float
Epsilon value for numerical stability
device : Union[torch.device, str]
Device to place the normalization modules on
seq_length : Optional[int], default = None
Sequence length for L2Normalization optimization
micro_batch_size : Optional[int], default = None
Micro batch size for L2Normalization optimization
Returns
-------
Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]
Query and key normalization modules (q_norm, k_norm)
"""
if
qk_norm_type
is
None
:
return
None
,
None
if
qk_norm_type
==
"L2Normalization"
:
l2_norm
=
L2Normalization
(
eps
=
qk_norm_eps
,
seq_length
=
seq_length
,
micro_batch_size
=
micro_batch_size
,
)
# L2Normalization is parameter-free, so we can share the same instance
return
l2_norm
,
l2_norm
if
qk_norm_type
==
"RMSNorm"
:
q_norm
=
RMSNorm
(
normalized_shape
=
self
.
hidden_size_per_attention_head
,
eps
=
qk_norm_eps
,
device
=
device
,
)
k_norm
=
RMSNorm
(
normalized_shape
=
self
.
hidden_size_per_attention_head
,
eps
=
qk_norm_eps
,
device
=
device
,
)
return
q_norm
,
k_norm
if
qk_norm_type
==
"LayerNorm"
:
q_norm
=
LayerNorm
(
normalized_shape
=
self
.
hidden_size_per_attention_head
,
eps
=
qk_norm_eps
,
device
=
device
,
)
k_norm
=
LayerNorm
(
normalized_shape
=
self
.
hidden_size_per_attention_head
,
eps
=
qk_norm_eps
,
device
=
device
,
)
return
q_norm
,
k_norm
raise
ValueError
(
f
"Unsupported QK norm type:
{
qk_norm_type
}
. "
"Supported types: ['L2Normalization', 'RMSNorm', 'LayerNorm']"
)
def
set_tensor_parallel_group
(
self
,
tp_group
:
Union
[
dist_group_type
,
None
])
->
None
:
"""
Set the tensor parallel group for the given
...
...
@@ -789,6 +867,14 @@ class MultiheadAttention(torch.nn.Module):
)
query_layer
=
query_layer
.
view
(
*
new_tensor_shape
)
# ===========================
# Apply normalization to query and key tensors (before RoPE if configured)
# ===========================
if
self
.
q_norm
is
not
None
and
self
.
qk_norm_before_rope
:
query_layer
=
self
.
q_norm
(
query_layer
)
key_layer
=
self
.
k_norm
(
key_layer
)
# ======================================================
# Apply relative positional encoding (rotary embedding)
# ======================================================
...
...
@@ -821,12 +907,19 @@ class MultiheadAttention(torch.nn.Module):
q_pos_emb
=
q_pos_emb
[
sequence_start
:
sequence_end
,
...]
k_pos_emb
=
k_pos_emb
[
sequence_start
:
sequence_end
,
...]
if
pad_between_seqs
:
rotary_pos_cu_seq_lens_q
=
cu_seqlens_q_padded
rotary_pos_cu_seq_lens_kv
=
cu_seqlens_kv_padded
else
:
rotary_pos_cu_seq_lens_q
=
cu_seqlens_q
rotary_pos_cu_seq_lens_kv
=
cu_seqlens_kv
query_layer
=
apply_rotary_pos_emb
(
query_layer
,
q_pos_emb
,
self
.
qkv_format
,
fused
=
True
,
cu_seqlens
=
cu_seqlens_q
,
cu_seqlens
=
rotary_pos_
cu_seq
_
lens_q
,
cp_size
=
self
.
cp_size
,
cp_rank
=
self
.
cp_rank
,
interleaved
=
self
.
rotary_pos_interleaved
,
...
...
@@ -836,19 +929,19 @@ class MultiheadAttention(torch.nn.Module):
k_pos_emb
,
self
.
qkv_format
,
fused
=
True
,
cu_seqlens
=
cu_seqlens_kv
,
cu_seqlens
=
rotary_pos_
cu_seq
_
lens_kv
,
cp_size
=
self
.
cp_size
,
cp_rank
=
self
.
cp_rank
,
interleaved
=
self
.
rotary_pos_interleaved
,
)
# ===========================
# Apply
L2
normalization to query and key tensors
# Apply normalization to query and key tensors
(after RoPE if not applied before)
# ===========================
if
self
.
use_qk_norm
:
query_layer
=
self
.
q
k
_norm
(
query_layer
)
key_layer
=
self
.
q
k_norm
(
key_layer
)
if
self
.
q_norm
is
not
None
and
not
self
.
qk_norm_before_rope
:
query_layer
=
self
.
q_norm
(
query_layer
)
key_layer
=
self
.
k_norm
(
key_layer
)
# ===========================
# Core attention computation
...
...
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
87e3e56e
...
...
@@ -46,6 +46,15 @@ __all__ = [
]
def
validate_gemm_scale
(
scale
:
Optional
[
float
],
required
:
bool
)
->
float
:
"""Validate whether a GEMM scaling factor is consistent with its usage"""
if
required
:
return
scale
if
scale
is
not
None
else
1.0
if
scale
not
in
(
0.0
,
None
):
raise
ValueError
(
"scale must be zero"
)
return
0.0
def
general_gemm
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
...
...
@@ -54,6 +63,8 @@ def general_gemm(
quantization_params
:
Optional
[
Quantizer
]
=
None
,
gelu
:
bool
=
False
,
gelu_in
:
torch
.
Tensor
=
None
,
alpha
:
float
=
1.0
,
beta
:
Optional
[
float
]
=
None
,
accumulate
:
bool
=
False
,
layout
:
str
=
"TN"
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -72,6 +83,9 @@ def general_gemm(
transb
=
layout
[
1
]
==
"T"
# assert quantization_params is None, "FP8 output not supported yet"
alpha
=
validate_gemm_scale
(
alpha
,
True
)
beta
=
validate_gemm_scale
(
beta
,
accumulate
)
# if ub_type is not None:
# assert ub is not None, (
# f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
...
...
@@ -349,6 +363,8 @@ def general_gemm(
"comm_type"
:
ub_type
,
"extra_output"
:
extra_output
,
"bulk_overlap"
:
bulk_overlap
,
"alpha"
:
alpha
,
"beta"
:
beta
,
}
out
,
bias_grad
,
gelu_input
,
extra_output
=
tex
.
generic_gemm
(
*
args
,
**
kwargs
)
...
...
transformer_engine/pytorch/cpu_offload.py
View file @
87e3e56e
...
...
@@ -431,7 +431,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor
=
self
.
fp8_tensor_object_map
.
pop
(
tensor_tag
)
if
self
.
double_buffering
:
tensor
.
do_not_clear
=
True
tensor
.
_
do_not_clear
=
True
self
.
tensor_tag_to_buf
.
pop
(
tensor_tag
,
None
)
# the tensor should have been copied back in on_group_commit_backward()
...
...
@@ -556,21 +556,33 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
for
tensor_label
,
state
in
self
.
tensor_tag_to_state
.
items
():
group_id
,
_
=
tensor_label
if
group_id
==
group_to_reload
:
if
self
.
double_buffering
:
reload_buffer
=
self
.
reload_double_buffer
[
double_buffer_idx
][
buffer_idx
]
else
:
reload_buffer
=
None
if
isinstance
(
state
,
tuple
):
recovered_tensor
=
SynchronizedGroupOffloadHandler
.
reload
(
state
,
True
,
self
.
reload_
double_buffer
[
double_buffer_idx
][
buffer_idx
]
state
,
True
,
reload_
buffer
)
buffer_idx
=
buffer_idx
+
1
self
.
tensor_tag_to_state
[
tensor_label
]
=
recovered_tensor
elif
isinstance
(
state
,
list
):
tensor_list
=
[]
for
state_tuple
in
state
:
if
self
.
double_buffering
:
reload_buffer
=
self
.
reload_double_buffer
[
double_buffer_idx
][
buffer_idx
]
else
:
reload_buffer
=
None
if
isinstance
(
state_tuple
,
tuple
):
tensor_list
.
append
(
SynchronizedGroupOffloadHandler
.
reload
(
state_tuple
,
True
,
self
.
reload_
double_buffer
[
double_buffer_idx
][
buffer_idx
]
,
reload_
buffer
,
)
)
buffer_idx
=
buffer_idx
+
1
...
...
transformer_engine/pytorch/cross_entropy.py
View file @
87e3e56e
...
...
@@ -29,6 +29,7 @@ class CrossEntropyFunction(torch.autograd.Function):
reduce_loss
=
False
,
dist_process_group
=
None
,
ignore_idx
=-
100
,
is_cg_capturable
=
False
,
):
"""
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
...
...
@@ -47,10 +48,16 @@ class CrossEntropyFunction(torch.autograd.Function):
tensor: The computed loss.
"""
loss
,
_input
=
triton_cross_entropy
.
cross_entropy_forward
(
_input
,
target
,
label_smoothing
,
reduce_loss
,
dist_process_group
,
ignore_idx
_input
,
target
,
label_smoothing
,
reduce_loss
,
dist_process_group
,
ignore_idx
,
)
ctx
.
save_for_backward
(
_input
.
detach
())
ctx
.
is_cg_capturable
=
is_cg_capturable
return
loss
@
staticmethod
...
...
@@ -66,13 +73,17 @@ class CrossEntropyFunction(torch.autograd.Function):
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(
_input
,)
=
ctx
.
saved_tensors
_input
=
triton_cross_entropy
.
cross_entropy_backward
(
_input
,
grad_output
)
_input
=
triton_cross_entropy
.
cross_entropy_backward
(
_input
,
grad_output
,
ctx
.
is_cg_capturable
)
return
(
_input
,
None
,
None
,
None
,
None
,
None
,
None
,
)
...
...
transformer_engine/pytorch/csrc/common.cpp
View file @
87e3e56e
...
...
@@ -12,7 +12,7 @@
namespace
transformer_engine
::
pytorch
{
std
::
vector
<
size_t
>
getTensorShape
(
at
::
Tensor
t
)
{
std
::
vector
<
size_t
>
getTensorShape
(
const
at
::
Tensor
&
t
)
{
std
::
vector
<
size_t
>
shape
;
for
(
auto
s
:
t
.
sizes
())
{
shape
.
push_back
(
s
);
...
...
@@ -286,7 +286,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return
std
::
vector
<
size_t
>
(
shape
.
data
,
shape
.
data
+
shape
.
ndim
);
}
in
t
roundup
(
const
in
t
value
,
const
in
t
multiple
)
{
size_
t
roundup
(
const
size_
t
value
,
const
size_
t
multiple
)
{
assert
(
multiple
>
0
);
return
((
value
+
multiple
-
1
)
/
multiple
)
*
multiple
;
}
...
...
transformer_engine/pytorch/csrc/common.h
View file @
87e3e56e
...
...
@@ -116,9 +116,21 @@ class Quantizer {
virtual
void
set_quantization_params
(
TensorWrapper
*
tensor
)
const
=
0
;
virtual
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
rowwise_data
=
std
::
nullopt
)
const
=
0
;
/*! @brief Construct a tensor with uninitialized data */
virtual
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
=
0
;
/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
*
* The PyTorch tensor's attributes are modified to match the
* quantizer's configuration.
*/
virtual
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
tensor
)
const
=
0
;
/*! @brief Convert to a quantized data format */
virtual
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
=
0
;
virtual
~
Quantizer
()
=
default
;
...
...
@@ -139,9 +151,17 @@ class NoneQuantizer : public Quantizer {
void
set_quantization_params
(
TensorWrapper
*
tensor
)
const
override
{}
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
rowwise_data
=
std
::
nullopt
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
/*! @brief Construct a tensor with pre-initialized data */
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
at
::
Tensor
data
)
const
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
tensor
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
};
class
Float8Quantizer
:
public
Quantizer
{
...
...
@@ -157,9 +177,19 @@ class Float8Quantizer : public Quantizer {
void
set_quantization_params
(
TensorWrapper
*
tensor
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
rowwise_data
=
std
::
nullopt
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
/*! @brief Construct a tensor with pre-initialized data */
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
data
,
std
::
optional
<
at
::
Tensor
>
transpose
,
std
::
optional
<
at
::
Tensor
>
scale_inv
)
const
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
};
class
Float8CurrentScalingQuantizer
:
public
Quantizer
{
...
...
@@ -179,9 +209,29 @@ class Float8CurrentScalingQuantizer : public Quantizer {
void
set_quantization_params
(
TensorWrapper
*
tensor
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
rowwise_data
=
std
::
nullopt
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
/*! @brief Construct a high precision tensor giving it this quantizer's amax
Note: this member function also zeros out the amax, as it is meant to be used in conjunction with
a kernel computing the amax, which might expect the amax to be initialized to zero
*/
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_hp_tensor_with_amax
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
);
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
/*! @brief Convert to a quantized data format avoiding amax computation */
void
quantize_with_amax
(
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
);
private:
void
quantize_impl
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
,
bool
compute_amax
);
};
class
Float8BlockQuantizer
:
public
Quantizer
{
...
...
@@ -213,9 +263,13 @@ class Float8BlockQuantizer : public Quantizer {
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
rowwise_data
=
std
::
nullopt
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
std
::
vector
<
size_t
>
get_scale_shape
(
const
std
::
vector
<
size_t
>&
shape
,
bool
columnwise
)
const
;
};
...
...
@@ -230,16 +284,20 @@ class MXFP8Quantizer : public Quantizer {
void
set_quantization_params
(
TensorWrapper
*
tensor
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
,
std
::
optional
<
at
::
Tensor
>
rowwise_data
=
std
::
nullopt
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
create_tensor
(
const
std
::
vector
<
size_t
>&
shape
,
DType
dtype
)
const
override
;
std
::
pair
<
TensorWrapper
,
py
::
object
>
convert_and_update_tensor
(
py
::
object
shape
)
const
override
;
void
quantize
(
const
TensorWrapper
&
input
,
TensorWrapper
&
out
,
const
std
::
optional
<
TensorWrapper
>&
noop_flag
=
std
::
nullopt
)
override
;
std
::
vector
<
size_t
>
get_scale_shape
(
const
std
::
vector
<
size_t
>&
shape
,
bool
columnwise
)
const
;
};
std
::
unique_ptr
<
Quantizer
>
convert_quantizer
(
py
::
handle
quantizer
);
std
::
vector
<
size_t
>
getTensorShape
(
at
::
Tensor
t
);
std
::
vector
<
size_t
>
getTensorShape
(
const
at
::
Tensor
&
t
);
transformer_engine
::
DType
getTransformerEngineFP8Type
(
bool
e4m3_if_hybrid
,
const
std
::
string
&
fp8_recipe
);
...
...
@@ -382,7 +440,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
std
::
vector
<
size_t
>
convertShape
(
const
NVTEShape
&
shape
);
in
t
roundup
(
const
in
t
value
,
const
in
t
multiple
);
size_
t
roundup
(
const
size_
t
value
,
const
size_
t
multiple
);
NVTEShape
convertTorchShape
(
const
c10
::
IntArrayRef
torch_shape
);
}
// namespace transformer_engine::pytorch
...
...
transformer_engine/pytorch/csrc/extensions.h
View file @
87e3e56e
...
...
@@ -11,6 +11,10 @@
#include "common.h"
class
CommOverlapHelper
;
class
CommOverlap
;
class
CommOverlapP2P
;
namespace
transformer_engine
::
pytorch
{
/***************************************************************************************************
...
...
@@ -118,7 +122,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
at
::
Tensor
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
CommOverlapCore
*
comm_overlap
=
nullptr
,
std
::
optional
<
CommOverlapType
>
comm_type
=
std
::
nullopt
,
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
);
MaybeTensor
extra_output
=
std
::
nullopt
,
bool
bulk_overlap
=
false
,
float
alpha
=
1.0
f
,
std
::
optional
<
float
>
beta
=
std
::
nullopt
);
void
te_atomic_gemm
(
at
::
Tensor
A
,
at
::
Tensor
A_scale_inverse
,
DType
A_type
,
std
::
vector
<
int64_t
>
A_scaling_mode
,
bool
transa
,
at
::
Tensor
B
,
...
...
@@ -179,6 +184,8 @@ std::vector<at::Tensor> te_batchgemm_ts(
at
::
Tensor
fp8_transpose
(
at
::
Tensor
input
,
DType
otype
,
std
::
optional
<
at
::
Tensor
>
output
=
std
::
nullopt
);
at
::
Tensor
swap_first_dims
(
at
::
Tensor
tensor
,
std
::
optional
<
at
::
Tensor
>
out
=
std
::
nullopt
);
/***************************************************************************************************
* Activations
**************************************************************************************************/
...
...
@@ -455,6 +462,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_k
void
nvshmem_finalize
();
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
void
bulk_overlap_ag_with_external_gemm
(
CommOverlap
&
allgather_communicator
,
at
::
Stream
send_stream
,
at
::
Stream
recv_stream
);
}
// namespace transformer_engine::pytorch
/***************************************************************************************************
...
...
@@ -504,7 +518,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
at
::
Tensor
get_buffer
(
bool
local_chunk
=
false
,
std
::
optional
<
std
::
vector
<
int64_t
>>
shape
=
std
::
nullopt
);
at
::
Stream
get_communication_stream
();
std
::
pair
<
at
::
Stream
,
at
::
Stream
>
get_communication_stream
();
};
// CommOverlap
...
...
@@ -525,7 +539,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
at
::
Tensor
get_buffer
(
bool
local_chunk
=
false
,
std
::
optional
<
std
::
vector
<
int64_t
>>
shape
=
std
::
nullopt
);
at
::
Stream
get_communication_stream
();
std
::
pair
<
at
::
Stream
,
at
::
Stream
>
get_communication_stream
();
};
// CommOverlapP2P
...
...
transformer_engine/pytorch/csrc/extensions/activation.cpp
View file @
87e3e56e
...
...
@@ -13,87 +13,92 @@ namespace transformer_engine::pytorch {
template
<
void
(
*
act_func
)(
const
NVTETensor
,
NVTETensor
,
cudaStream_t
)>
py
::
object
activation_helper
(
const
at
::
Tensor
&
input
,
py
::
handle
quantizer
,
int
shape_divisor
=
1
)
{
init_extension
();
auto
my_quantizer
=
convert_quantizer
(
quantizer
);
auto
input_tensor
=
input
.
contiguous
();
const
TensorWrapper
&
te_input
=
makeTransformerEngineTensor
(
input_tensor
);
const
auto
&
te_input_shape
=
te_input
.
shape
();
std
::
vector
<
size_t
>
input_shape
(
te_input_shape
.
data
,
te_input_shape
.
data
+
te_input_shape
.
ndim
);
input_shape
[
input_shape
.
size
()
-
1
]
/=
shape_divisor
;
auto
fake_tensor_type
=
input
.
scalar_type
();
auto
[
te_output
,
out
]
=
my_quantizer
->
create_tensor
(
input_shape
,
GetTransformerEngineDType
(
fake_tensor_type
));
// for current scaling, we need to compute amax first and then quantize
// because cache cannot fit in the entire tensor to compute amax and quantize
// the quantizer should not need amax reduction, no process group needed here
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
()))
{
// activation function might change the input data range, we need to first call the activation function
// and then find the amax and scale of that and then do the quantization
// get a NoneQuantizer to calculate amax of activation output
auto
my_quantizer_none
=
std
::
make_unique
<
NoneQuantizer
>
(
py
::
none
());
auto
[
te_output_act
,
out_act
]
=
my_quantizer_none
->
create_tensor
(
input_shape
,
GetTransformerEngineDType
(
fake_tensor_type
));
NVTE_SCOPED_GIL_RELEASE
({
act_func
(
te_input
.
data
(),
te_output_act
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
// use te_output_act as input to the compute amax and find the amax of activated tensor
nvte_compute_amax
(
te_output_act
.
data
(),
te_output
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto
my_quantizer_cs
=
static_cast
<
Float8CurrentScalingQuantizer
*>
(
my_quantizer
.
get
());
if
(
my_quantizer_cs
->
with_amax_reduction
)
{
NVTE_ERROR
(
"per-tensor current scaling amax reduction is not supported in activation functions."
);
}
QuantizationConfigWrapper
quant_config
;
quant_config
.
set_force_pow_2_scales
(
my_quantizer_cs
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_cs
->
amax_epsilon
);
// Input tensor
auto
input_tensor
=
input
.
contiguous
();
const
TensorWrapper
&
input_cpp
=
makeTransformerEngineTensor
(
input_tensor
);
// Construct output tensor
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
const
auto
input_shape
=
input_cpp
.
shape
();
std
::
vector
<
size_t
>
output_shape
(
input_shape
.
data
,
input_shape
.
data
+
input_shape
.
ndim
);
output_shape
.
back
()
/=
shape_divisor
;
auto
fake_dtype
=
GetTransformerEngineDType
(
input_tensor
.
scalar_type
());
auto
[
out_cpp
,
out_py
]
=
quantizer_cpp
->
create_tensor
(
output_shape
,
fake_dtype
);
// Compute activation
if
(
quantizer
.
is_none
()
||
detail
::
IsFloat8Quantizers
(
quantizer
.
ptr
())
||
detail
::
IsMXFP8Quantizers
(
quantizer
.
ptr
()))
{
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE
(
{
act_func
(
input_cpp
.
data
(),
out_cpp
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
}
else
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
()))
{
// Compute activation in high-precision fused together with amax, then quantize.
NVTE_SCOPED_GIL_RELEASE
({
nvte_compute_scale_from_amax
(
te_output
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output
.
set_amax
(
nullptr
,
DType
::
kFloat32
,
te_output
.
defaultShape
);
nvte_quantize_v2
(
te_output_act
.
data
(),
te_output
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
}
else
if
(
detail
::
IsFloat8BlockwiseQuantizers
(
quantizer
.
ptr
()))
{
// sanity check, since activation fusion is not supported for blockwise quantization yet
// need to raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR
(
"Activation fusion is not supported for blockwise quantization yet."
);
auto
quantizer_cpp_cs
=
dynamic_cast
<
Float8CurrentScalingQuantizer
*>
(
quantizer_cpp
.
get
());
auto
[
temp_cpp
,
_
]
=
quantizer_cpp_cs
->
create_hp_tensor_with_amax
(
output_shape
,
fake_dtype
);
NVTE_SCOPED_GIL_RELEASE
(
{
act_func
(
input_cpp
.
data
(),
temp_cpp
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
quantizer_cpp_cs
->
quantize_with_amax
(
temp_cpp
,
out_cpp
);
}
else
{
// Compute activation in high-precision, then quantize
auto
[
temp_cpp
,
_
]
=
NoneQuantizer
(
py
::
none
()).
create_tensor
(
output_shape
,
fake_dtype
);
NVTE_SCOPED_GIL_RELEASE
(
{
act_func
(
te_input
.
data
(),
te_output
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
{
act_func
(
input_cpp
.
data
(),
temp_cpp
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
quantizer_cpp
->
quantize
(
temp_cpp
,
out_cpp
);
}
return
out
;
return
out
_py
;
}
template
<
void
(
*
act_func
)(
const
NVTETensor
,
const
NVTETensor
,
NVTETensor
,
cudaStream_t
)>
py
::
object
dactivation_helper
(
const
at
::
Tensor
&
grad
,
const
at
::
Tensor
&
input
,
template
<
void
(
*
d
act_func
)(
const
NVTETensor
,
const
NVTETensor
,
NVTETensor
,
cudaStream_t
)>
py
::
object
dactivation_helper
(
const
at
::
Tensor
&
grad
_output
,
const
at
::
Tensor
&
input
,
py
::
handle
quantizer
)
{
init_extension
();
auto
my_quantizer
=
convert_quantizer
(
quantizer
);
auto
input_tensor
=
input
.
contiguous
();
auto
grad_tensor
=
grad
.
contiguous
();
const
TensorWrapper
&
te_input
=
makeTransformerEngineTensor
(
input_tensor
);
const
TensorWrapper
&
te_grad
=
makeTransformerEngineTensor
(
grad_tensor
);
const
auto
&
te_input_shape
=
te_input
.
shape
();
std
::
vector
<
size_t
>
input_shape
(
te_input_shape
.
data
,
te_input_shape
.
data
+
te_input_shape
.
ndim
);
auto
fake_tensor_type
=
input
.
scalar_type
();
auto
[
te_output
,
out
]
=
my_quantizer
->
create_tensor
(
input_shape
,
GetTransformerEngineDType
(
fake_tensor_type
));
NVTE_SCOPED_GIL_RELEASE
({
act_func
(
te_grad
.
data
(),
te_input
.
data
(),
te_output
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
// Grad output and input tensors
auto
grad_output_tensor
=
grad_output
.
contiguous
();
auto
input_tensor
=
input
.
contiguous
();
const
TensorWrapper
&
grad_output_cpp
=
makeTransformerEngineTensor
(
grad_output_tensor
);
const
TensorWrapper
&
input_cpp
=
makeTransformerEngineTensor
(
input_tensor
);
// Construct grad input tensor
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
const
auto
input_shape_te
=
input_cpp
.
shape
();
const
std
::
vector
<
size_t
>
input_shape
(
input_shape_te
.
data
,
input_shape_te
.
data
+
input_shape_te
.
ndim
);
auto
fake_dtype
=
GetTransformerEngineDType
(
input_tensor
.
scalar_type
());
auto
[
grad_input_cpp
,
grad_input_py
]
=
quantizer_cpp
->
create_tensor
(
input_shape
,
fake_dtype
);
// Compute activation backward
if
(
quantizer
.
is_none
()
||
detail
::
IsFloat8Quantizers
(
quantizer
.
ptr
())
||
detail
::
IsMXFP8Quantizers
(
quantizer
.
ptr
()))
{
// Compute activation backward directly
NVTE_SCOPED_GIL_RELEASE
({
dact_func
(
grad_output_cpp
.
data
(),
input_cpp
.
data
(),
grad_input_cpp
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
}
else
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
()))
{
// Compute activation backward in high-precision fused together with amax, then quantize.
auto
quantizer_cpp_cs
=
dynamic_cast
<
Float8CurrentScalingQuantizer
*>
(
quantizer_cpp
.
get
());
auto
[
temp_cpp
,
_
]
=
quantizer_cpp_cs
->
create_hp_tensor_with_amax
(
input_shape
,
fake_dtype
);
NVTE_SCOPED_GIL_RELEASE
({
dact_func
(
grad_output_cpp
.
data
(),
input_cpp
.
data
(),
temp_cpp
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
quantizer_cpp_cs
->
quantize_with_amax
(
temp_cpp
,
grad_input_cpp
);
}
else
{
// Compute activation backward in high-precision, then quantize
auto
[
temp_cpp
,
_
]
=
NoneQuantizer
(
py
::
none
()).
create_tensor
(
input_shape
,
fake_dtype
);
NVTE_SCOPED_GIL_RELEASE
({
dact_func
(
grad_output_cpp
.
data
(),
input_cpp
.
data
(),
temp_cpp
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
quantizer_cpp
->
quantize
(
temp_cpp
,
grad_input_cpp
);
}
return
out
;
return
grad_input_py
;
}
py
::
object
gelu
(
const
at
::
Tensor
&
input
,
py
::
handle
quantizer
)
{
...
...
transformer_engine/pytorch/csrc/extensions/attention.cpp
View file @
87e3e56e
...
...
@@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
auto
max_tokens
=
shape
[
0
];
auto
fcd_size
=
1
;
for
(
in
t
i
=
1
;
i
<=
shape
.
size
();
i
++
)
{
for
(
size_
t
i
=
1
;
i
<=
shape
.
size
();
i
++
)
{
fcd_size
*=
shape
[
i
];
}
...
...
@@ -110,8 +110,20 @@ std::vector<py::object> fused_attn_fwd(
auto
o_shape
=
std
::
vector
<
size_t
>
{
q_shape
.
begin
(),
q_shape
.
end
()};
o_shape
[
o_shape
.
size
()
-
1
]
=
v_shape
[
v_shape
.
size
()
-
1
];
py
::
object
o_python
,
s_python
;
std
::
tie
(
te_O
,
o_python
)
=
O_quantizer
->
create_tensor
(
o_shape
,
fake_dtype_te
);
std
::
tie
(
te_S
,
s_python
)
=
S_quantizer
->
create_tensor
({
0
},
DType
::
kFloat32
);
if
(
qkv_type
==
DType
::
kFloat8E4M3
||
qkv_type
==
DType
::
kFloat8E5M2
)
{
// Initialize FP8 tensor with scale-inverse
auto
*
O_quantizer_fp8
=
dynamic_cast
<
Float8Quantizer
*>
(
O_quantizer
.
get
());
auto
*
S_quantizer_fp8
=
dynamic_cast
<
Float8Quantizer
*>
(
S_quantizer
.
get
());
NVTE_CHECK
(
O_quantizer_fp8
!=
nullptr
,
"Expected Float8Quantizer when dtype is FP8"
);
NVTE_CHECK
(
S_quantizer_fp8
!=
nullptr
,
"Expected Float8Quantizer when dtype is FP8"
);
std
::
tie
(
te_O
,
o_python
)
=
O_quantizer_fp8
->
create_tensor
(
o_shape
,
fake_dtype_te
,
std
::
nullopt
,
std
::
nullopt
,
std
::
nullopt
);
std
::
tie
(
te_S
,
s_python
)
=
S_quantizer_fp8
->
create_tensor
({
0
},
DType
::
kFloat32
,
std
::
nullopt
,
std
::
nullopt
,
std
::
nullopt
);
}
else
{
std
::
tie
(
te_O
,
o_python
)
=
O_quantizer
->
create_tensor
(
o_shape
,
fake_dtype_te
);
std
::
tie
(
te_S
,
s_python
)
=
S_quantizer
->
create_tensor
({
0
},
DType
::
kFloat32
);
}
auto
o_shape_int64
=
std
::
vector
<
int64_t
>
{
o_shape
.
begin
(),
o_shape
.
end
()};
// construct NVTE tensors
...
...
@@ -295,8 +307,20 @@ std::vector<py::object> fused_attn_bwd(
py
::
object
s_python
,
dp_python
;
std
::
unique_ptr
<
Quantizer
>
S_quantizer
=
convert_quantizer
(
s_quantizer
);
std
::
unique_ptr
<
Quantizer
>
dP_quantizer
=
convert_quantizer
(
dp_quantizer
);
std
::
tie
(
te_S
,
s_python
)
=
S_quantizer
->
create_tensor
({
0
},
DType
::
kFloat32
);
std
::
tie
(
te_dP
,
dp_python
)
=
dP_quantizer
->
create_tensor
({
0
},
DType
::
kFloat32
);
if
(
qkv_type
==
DType
::
kFloat8E4M3
||
qkv_type
==
DType
::
kFloat8E5M2
)
{
auto
*
S_quantizer_fp8
=
dynamic_cast
<
Float8Quantizer
*>
(
S_quantizer
.
get
());
auto
*
dP_quantizer_fp8
=
dynamic_cast
<
Float8Quantizer
*>
(
dP_quantizer
.
get
());
NVTE_CHECK
(
S_quantizer_fp8
!=
nullptr
,
"Expected Float8Quantizer when dtype is FP8"
);
NVTE_CHECK
(
dP_quantizer_fp8
!=
nullptr
,
"Expected Float8Quantizer when dtype is FP8"
);
std
::
tie
(
te_S
,
s_python
)
=
S_quantizer_fp8
->
create_tensor
({
0
},
DType
::
kFloat32
,
std
::
nullopt
,
std
::
nullopt
,
std
::
nullopt
);
std
::
tie
(
te_dP
,
dp_python
)
=
dP_quantizer_fp8
->
create_tensor
({
0
},
DType
::
kFloat32
,
std
::
nullopt
,
std
::
nullopt
,
std
::
nullopt
);
}
else
{
std
::
tie
(
te_S
,
s_python
)
=
S_quantizer
->
create_tensor
({
0
},
DType
::
kFloat32
);
std
::
tie
(
te_dP
,
dp_python
)
=
dP_quantizer
->
create_tensor
({
0
},
DType
::
kFloat32
);
}
std
::
vector
<
size_t
>
q_shape
=
convertShape
(
te_Q
.
shape
());
std
::
vector
<
size_t
>
k_shape
=
convertShape
(
te_K
.
shape
());
...
...
@@ -385,9 +409,22 @@ std::vector<py::object> fused_attn_bwd(
default:
NVTE_ERROR
(
"QKV layout not supported!"
);
}
std
::
tie
(
te_dQ
,
py_dQ
)
=
dQKV_quantizer
->
create_tensor
(
q_shape
,
fake_dtype_te
,
dQ
);
std
::
tie
(
te_dK
,
py_dK
)
=
dQKV_quantizer
->
create_tensor
(
k_shape
,
fake_dtype_te
,
dK
);
std
::
tie
(
te_dV
,
py_dV
)
=
dQKV_quantizer
->
create_tensor
(
v_shape
,
fake_dtype_te
,
dV
);
if
(
qkv_type
==
DType
::
kFloat8E4M3
||
qkv_type
==
DType
::
kFloat8E5M2
)
{
auto
*
fp8_quantizer
=
dynamic_cast
<
Float8Quantizer
*>
(
dQKV_quantizer
.
get
());
NVTE_CHECK
(
fp8_quantizer
!=
nullptr
,
"Expected Float8Quantizer when dtype is FP8"
);
std
::
tie
(
te_dQ
,
py_dQ
)
=
fp8_quantizer
->
create_tensor
(
q_shape
,
fake_dtype_te
,
dQ
,
std
::
nullopt
,
std
::
nullopt
);
std
::
tie
(
te_dK
,
py_dK
)
=
fp8_quantizer
->
create_tensor
(
k_shape
,
fake_dtype_te
,
dK
,
std
::
nullopt
,
std
::
nullopt
);
std
::
tie
(
te_dV
,
py_dV
)
=
fp8_quantizer
->
create_tensor
(
v_shape
,
fake_dtype_te
,
dV
,
std
::
nullopt
,
std
::
nullopt
);
}
else
{
auto
*
none_quantizer
=
dynamic_cast
<
NoneQuantizer
*>
(
dQKV_quantizer
.
get
());
NVTE_CHECK
(
none_quantizer
!=
nullptr
,
"Expected NoneQuantizer when dtype is not FP8"
);
std
::
tie
(
te_dQ
,
py_dQ
)
=
none_quantizer
->
create_tensor
(
q_shape
,
fake_dtype_te
,
dQ
);
std
::
tie
(
te_dK
,
py_dK
)
=
none_quantizer
->
create_tensor
(
k_shape
,
fake_dtype_te
,
dK
);
std
::
tie
(
te_dV
,
py_dV
)
=
none_quantizer
->
create_tensor
(
v_shape
,
fake_dtype_te
,
dV
);
}
// construct NVTE tensors
if
(
qkv_type
==
DType
::
kFloat8E4M3
||
qkv_type
==
DType
::
kFloat8E5M2
)
{
...
...
transformer_engine/pytorch/csrc/extensions/bias.cpp
View file @
87e3e56e
...
...
@@ -4,80 +4,223 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <utility>
#include <vector>
#include "common.h"
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/transformer_engine.h"
namespace
transformer_engine
::
pytorch
{
namespace
transformer_engine
{
namespace
pytorch
{
std
::
vector
<
py
::
object
>
bgrad_quantize
(
const
at
::
Tensor
&
input
,
py
::
handle
py_quantizer
)
{
auto
quantizer
=
convert_quantizer
(
py_quantizer
);
std
::
vector
<
py
::
object
>
bgrad_quantize
(
const
at
::
Tensor
&
grad_output
,
py
::
handle
quantizer
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
init_extension
();
auto
input_tensor
=
makeTransformerEngineTensor
(
input
);
// Grad output tensor
auto
grad_output_torch
=
grad_output
.
contiguous
();
const
TensorWrapper
&
grad_output_nvte
=
makeTransformerEngineTensor
(
grad_output_torch
);
const
auto
shape
=
getTensorShape
(
grad_output_torch
);
auto
grad_output_dtype
=
GetTransformerEngineDType
(
grad_output_torch
.
scalar_type
());
auto
dbias
=
allocateTorchTensor
(
input
.
size
(
-
1
),
input_tensor
.
dtype
());
// Construct grad bias tensor
const
int64_t
bias_size
=
static_cast
<
int64_t
>
(
shape
.
back
());
auto
grad_bias_torch
=
allocateTorchTensor
(
bias_size
,
grad_output_dtype
);
auto
grad_bias_nvte
=
makeTransformerEngineTensor
(
grad_bias_torch
);
std
::
vector
<
size_t
>
output_shape
;
for
(
auto
s
:
input
.
sizes
())
{
output_shape
.
emplace_back
(
static_cast
<
size_t
>
(
s
));
// Unquantized impl only requires computing grad bias
if
(
quantizer
.
is_none
())
{
if
(
product
(
shape
)
==
0
)
{
grad_bias_torch
.
zero_
();
}
else
{
at
::
sum_out
(
grad_bias_torch
,
grad_output_torch
.
reshape
({
-
1
,
bias_size
}),
{
0
});
}
return
{
py
::
cast
(
std
::
move
(
grad_bias_torch
)),
py
::
cast
(
std
::
move
(
grad_output_torch
))};
}
auto
[
out_tensor
,
out
]
=
quantizer
->
create_tensor
(
output_shape
,
input_tensor
.
dtype
());
// Return immediately if tensors are empty
if
(
product
(
output_shape
)
==
0
)
{
return
{
py
::
cast
(
dbias
.
zero_
()),
out
};
// Construct grad input tensor
auto
quantizer_cpp
=
convert_quantizer
(
quantizer
);
auto
[
grad_input_nvte
,
grad_input_py
]
=
quantizer_cpp
->
create_tensor
(
shape
,
grad_output_dtype
);
// Trivial impl if tensors are empty
if
(
product
(
shape
)
==
0
)
{
grad_bias_torch
.
zero_
();
return
{
py
::
cast
(
std
::
move
(
grad_bias_torch
)),
std
::
move
(
grad_input_py
)};
}
// Unfused impl if quantizer is not supported
const
bool
with_fused_dbias_quantize_kernel
=
detail
::
IsFloat8Quantizers
(
quantizer
.
ptr
())
||
detail
::
IsMXFP8Quantizers
(
quantizer
.
ptr
());
if
(
!
with_fused_dbias_quantize_kernel
)
{
at
::
sum_out
(
grad_bias_torch
,
grad_output_torch
.
reshape
({
-
1
,
bias_size
}),
{
0
});
quantizer_cpp
->
quantize
(
grad_output_nvte
,
grad_input_nvte
);
return
{
py
::
cast
(
std
::
move
(
grad_bias_torch
)),
std
::
move
(
grad_input_py
)};
}
auto
dbias_tensor
=
makeTransformerEngineTensor
(
dbias
);
// Query workspace size and allocate workspace
transformer_engine
::
TensorWrapper
workspace
;
// Query workspace size
TensorWrapper
workspace_nvte
;
at
::
Tensor
workspace_torch
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_dbias
(
input_tensor
.
data
(),
out_tensor
.
data
(),
dbias_tensor
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentCUDAS
tream
()
);
nvte_quantize_dbias
(
grad_output_nvte
.
data
(),
grad_input_nvte
.
data
(),
grad_bias_nvte
.
data
(),
workspace
_nvte
.
data
(),
s
tream
);
});
void
*
workspace_data_ptr
=
nullptr
;
if
(
workspace
.
shape
().
ndim
>
0
)
{
auto
workspace_data
=
allocateSpace
(
workspace
.
shape
(),
workspace
.
dtype
());
workspace_data_ptr
=
workspace_data
.
data_ptr
();
}
workspace
=
makeTransformerEngineTensor
(
workspace_data_ptr
,
workspace
.
shape
(),
workspace
.
dtype
());
// Launch kernel
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
py_quantizer
.
ptr
()))
{
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto
my_quantizer_cs
=
static_cast
<
Float8CurrentScalingQuantizer
*>
(
quantizer
.
get
());
NVTE_SCOPED_GIL_RELEASE
({
nvte_compute_amax
(
input_tensor
.
data
(),
out_tensor
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
// check if we need to do amax reudction (depending on model parallel configs)
if
(
my_quantizer_cs
->
with_amax_reduction
)
{
c10
::
intrusive_ptr
<
dist_group_type
>
process_group_ptr
=
my_quantizer_cs
->
amax_reduction_group
;
// construct torch tesnor from NVTEBasicTensor without reallocating memory
at
::
Tensor
&
amax_tensor_torch
=
my_quantizer_cs
->
amax
;
std
::
vector
<
at
::
Tensor
>
tensors
=
{
amax_tensor_torch
};
// allreduce amax tensor
c10d
::
AllreduceOptions
allreduce_opts
;
allreduce_opts
.
reduceOp
=
c10d
::
ReduceOp
::
MAX
;
process_group_ptr
->
allreduce
(
tensors
,
allreduce_opts
)
->
wait
();
}
QuantizationConfigWrapper
quant_config
;
quant_config
.
set_force_pow_2_scales
(
my_quantizer_cs
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_cs
->
amax_epsilon
);
NVTE_SCOPED_GIL_RELEASE
({
nvte_compute_scale_from_amax
(
out_tensor
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_tensor
.
set_amax
(
nullptr
,
DType
::
kFloat32
,
out_tensor
.
defaultShape
);
// Allocate workspace
if
(
workspace_nvte
.
ndim
()
>
0
&&
workspace_nvte
.
numel
()
>
0
)
{
workspace_torch
=
allocateSpace
(
workspace_nvte
.
shape
(),
workspace_nvte
.
dtype
());
workspace_nvte
=
makeTransformerEngineTensor
(
workspace_torch
.
data_ptr
(),
workspace_nvte
.
shape
(),
workspace_nvte
.
dtype
());
}
// Launch fused kernel
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_dbias
(
input_tensor
.
data
(),
out_tensor
.
data
(),
dbias_tensor
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentCUDAS
tream
()
);
nvte_quantize_dbias
(
grad_output_nvte
.
data
(),
grad_input_nvte
.
data
(),
grad_bias_nvte
.
data
(),
workspace
_nvte
.
data
(),
s
tream
);
});
return
{
py
::
cast
(
dbias
),
out
};
return
{
py
::
cast
(
std
::
move
(
grad_bias_torch
)),
std
::
move
(
grad_input_py
)};
}
namespace
{
std
::
vector
<
py
::
object
>
dact_dbias
(
void
(
*
dact_dbias_func
)(
const
NVTETensor
,
const
NVTETensor
,
NVTETensor
,
NVTETensor
,
NVTETensor
,
cudaStream_t
),
void
(
*
dact_func
)(
const
NVTETensor
,
const
NVTETensor
,
NVTETensor
,
cudaStream_t
),
at
::
Tensor
grad_output_torch
,
at
::
Tensor
act_input_torch
,
py
::
handle
quantizer_py
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
init_extension
();
// Grad output and activation input tensors
grad_output_torch
=
grad_output_torch
.
contiguous
();
const
TensorWrapper
&
grad_output_nvte
=
makeTransformerEngineTensor
(
grad_output_torch
);
const
auto
output_shape
=
getTensorShape
(
grad_output_torch
);
auto
grad_output_dtype
=
GetTransformerEngineDType
(
grad_output_torch
.
scalar_type
());
act_input_torch
=
act_input_torch
.
contiguous
();
const
TensorWrapper
&
act_input_nvte
=
makeTransformerEngineTensor
(
act_input_torch
);
const
auto
input_shape
=
getTensorShape
(
act_input_torch
);
// Construct tensors
auto
quantizer_cpp
=
convert_quantizer
(
quantizer_py
);
auto
[
grad_input_nvte
,
grad_input_py
]
=
quantizer_cpp
->
create_tensor
(
input_shape
,
grad_output_dtype
);
const
int64_t
bias_size
=
static_cast
<
int64_t
>
(
input_shape
.
back
());
auto
grad_bias_torch
=
allocateTorchTensor
(
bias_size
,
grad_output_dtype
);
auto
grad_bias_nvte
=
makeTransformerEngineTensor
(
grad_bias_torch
);
// Return immediately if tensors are empty
if
(
product
(
output_shape
)
==
0
)
{
grad_bias_torch
.
zero_
();
return
{
py
::
cast
(
std
::
move
(
grad_bias_torch
)),
std
::
move
(
grad_input_py
)};
}
// Choose implementation
enum
class
Impl
{
UNFUSED
,
FUSED_DACT_DBIAS_QUANTIZE
,
FUSED_DACT_AMAX
};
Impl
impl
=
Impl
::
UNFUSED
;
if
(
detail
::
IsFloat8Quantizers
(
quantizer_py
.
ptr
())
||
detail
::
IsMXFP8Quantizers
(
quantizer_py
.
ptr
()))
{
impl
=
Impl
::
FUSED_DACT_DBIAS_QUANTIZE
;
}
else
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
quantizer_py
.
ptr
()))
{
impl
=
Impl
::
FUSED_DACT_AMAX
;
}
// Perform compute
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
switch
(
impl
)
{
case
Impl
::
UNFUSED
:
// Unfused dact, dbias, quantize
{
auto
[
temp_nvte
,
temp_py
]
=
NoneQuantizer
(
py
::
none
()).
create_tensor
(
input_shape
,
grad_output_dtype
);
NVTE_SCOPED_GIL_RELEASE
({
dact_func
(
grad_output_nvte
.
data
(),
act_input_nvte
.
data
(),
temp_nvte
.
data
(),
stream
);
});
const
auto
temp_torch
=
temp_py
.
cast
<
at
::
Tensor
>
();
at
::
sum_out
(
grad_bias_torch
,
temp_torch
.
reshape
({
-
1
,
bias_size
}),
{
0
});
quantizer_cpp
->
quantize
(
temp_nvte
,
grad_input_nvte
);
break
;
}
case
Impl
::
FUSED_DACT_DBIAS_QUANTIZE
:
// Fused dact-dbias-quantize kernel
{
// Query workspace size
TensorWrapper
workspace_nvte
;
NVTE_SCOPED_GIL_RELEASE
({
dact_dbias_func
(
grad_output_nvte
.
data
(),
act_input_nvte
.
data
(),
grad_input_nvte
.
data
(),
grad_bias_nvte
.
data
(),
workspace_nvte
.
data
(),
stream
);
});
// Allocate workspace
at
::
Tensor
workspace_torch
;
if
(
workspace_nvte
.
ndim
()
>
0
&&
workspace_nvte
.
numel
()
>
0
)
{
workspace_torch
=
allocateSpace
(
workspace_nvte
.
shape
(),
workspace_nvte
.
dtype
());
workspace_nvte
=
makeTransformerEngineTensor
(
workspace_torch
.
data_ptr
(),
workspace_nvte
.
shape
(),
workspace_nvte
.
dtype
());
}
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
({
dact_dbias_func
(
grad_output_nvte
.
data
(),
act_input_nvte
.
data
(),
grad_input_nvte
.
data
(),
grad_bias_nvte
.
data
(),
workspace_nvte
.
data
(),
stream
);
});
break
;
}
case
Impl
::
FUSED_DACT_AMAX
:
// Fused dact-amax kernel, unfused dbias and quantize
{
auto
*
quantizer_cpp_cs
=
dynamic_cast
<
Float8CurrentScalingQuantizer
*>
(
quantizer_cpp
.
get
());
NVTE_CHECK
(
quantizer_cpp_cs
!=
nullptr
,
"Invalid quantizer for fused dact-amax kernel impl"
);
auto
[
temp_nvte
,
temp_py
]
=
quantizer_cpp_cs
->
create_hp_tensor_with_amax
(
input_shape
,
grad_output_dtype
);
NVTE_SCOPED_GIL_RELEASE
({
dact_func
(
grad_output_nvte
.
data
(),
act_input_nvte
.
data
(),
temp_nvte
.
data
(),
stream
);
});
const
auto
temp_torch
=
temp_py
.
cast
<
at
::
Tensor
>
();
at
::
sum_out
(
grad_bias_torch
,
temp_torch
.
reshape
({
-
1
,
bias_size
}),
{
0
});
quantizer_cpp_cs
->
quantize_with_amax
(
temp_nvte
,
grad_input_nvte
);
break
;
}
default:
NVTE_ERROR
(
"Invalid implementation"
);
}
return
{
py
::
cast
(
std
::
move
(
grad_bias_torch
)),
std
::
move
(
grad_input_py
)};
}
}
// namespace
std
::
vector
<
py
::
object
>
dbias_dgelu
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
act_input
,
py
::
handle
quantizer
)
{
return
dact_dbias
(
nvte_quantize_dbias_dgelu
,
nvte_dgelu
,
grad_output
,
act_input
,
quantizer
);
}
std
::
vector
<
py
::
object
>
dbias_dsilu
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
act_input
,
py
::
handle
quantizer
)
{
return
dact_dbias
(
nvte_quantize_dbias_dsilu
,
nvte_dsilu
,
grad_output
,
act_input
,
quantizer
);
}
std
::
vector
<
py
::
object
>
dbias_drelu
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
act_input
,
py
::
handle
quantizer
)
{
return
dact_dbias
(
nvte_quantize_dbias_drelu
,
nvte_drelu
,
grad_output
,
act_input
,
quantizer
);
}
std
::
vector
<
py
::
object
>
dbias_dqgelu
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
act_input
,
py
::
handle
quantizer
)
{
return
dact_dbias
(
nvte_quantize_dbias_dqgelu
,
nvte_dqgelu
,
grad_output
,
act_input
,
quantizer
);
}
std
::
vector
<
py
::
object
>
dbias_dsrelu
(
const
at
::
Tensor
&
grad_output
,
const
at
::
Tensor
&
act_input
,
py
::
handle
quantizer
)
{
return
dact_dbias
(
nvte_quantize_dbias_dsrelu
,
nvte_dsrelu
,
grad_output
,
act_input
,
quantizer
);
}
}
// namespace transformer_engine::pytorch
}
// namespace pytorch
}
// namespace transformer_engine
transformer_engine/pytorch/csrc/extensions/cast.cpp
View file @
87e3e56e
...
...
@@ -28,60 +28,6 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
return
std
::
vector
<
size_t
>
(
shape
.
data
,
shape
.
data
+
shape
.
ndim
);
}
void
quantize_impl
(
const
TensorWrapper
&
input
,
py
::
handle
&
quantizer_py
,
std
::
unique_ptr
<
Quantizer
>
&
quantizer_cpp
,
TensorWrapper
&
output
,
TensorWrapper
&
noop_flag
)
{
// Check tensor dims
NVTE_CHECK
(
get_tensor_shape
(
input
)
==
get_tensor_shape
(
output
),
"Input tensor (shape="
,
get_tensor_shape
(
input
),
") and output tensor (shape="
,
get_tensor_shape
(
output
),
") do not match"
);
if
(
input
.
numel
()
==
0
)
{
return
;
}
// Recipe-specific configuration
QuantizationConfigWrapper
quant_config
;
quant_config
.
set_noop_tensor
(
noop_flag
.
data
());
if
(
detail
::
IsFloat8CurrentScalingQuantizers
(
quantizer_py
.
ptr
()))
{
auto
my_quantizer_cs
=
static_cast
<
Float8CurrentScalingQuantizer
*>
(
quantizer_cpp
.
get
());
NVTE_SCOPED_GIL_RELEASE
(
{
nvte_compute_amax
(
input
.
data
(),
output
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
// check if we need to do amax reudction (depending on model parallel configs)
if
(
my_quantizer_cs
->
with_amax_reduction
)
{
c10
::
intrusive_ptr
<
dist_group_type
>
process_group_ptr
=
my_quantizer_cs
->
amax_reduction_group
;
// construct torch tesnor from NVTEBasicTensor without reallocating memory
at
::
Tensor
&
amax_tensor_torch
=
my_quantizer_cs
->
amax
;
std
::
vector
<
at
::
Tensor
>
tensors
=
{
amax_tensor_torch
};
// allreduce amax tensor
c10d
::
AllreduceOptions
allreduce_opts
;
allreduce_opts
.
reduceOp
=
c10d
::
ReduceOp
::
MAX
;
process_group_ptr
->
allreduce
(
tensors
,
allreduce_opts
)
->
wait
();
}
// this config is used for cs scaling factor computation
// because compute scale is cannot be fused with quantize kernel
// so in nvte_quantize_v2 with current scaling, the quant config is not used again
quant_config
.
set_force_pow_2_scales
(
my_quantizer_cs
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_cs
->
amax_epsilon
);
NVTE_SCOPED_GIL_RELEASE
({
nvte_compute_scale_from_amax
(
output
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
output
.
set_amax
(
nullptr
,
DType
::
kFloat32
,
output
.
defaultShape
);
}
else
if
(
detail
::
IsFloat8BlockwiseQuantizers
(
quantizer_py
.
ptr
()))
{
auto
my_quantizer_bw
=
static_cast
<
Float8BlockQuantizer
*>
(
quantizer_cpp
.
get
());
quant_config
.
set_force_pow_2_scales
(
my_quantizer_bw
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_bw
->
amax_epsilon
);
if
(
my_quantizer_bw
->
all_gather_usage
)
{
quant_config
.
set_float8_block_scale_tensor_format
(
Float8BlockScaleTensorFormat
::
COMPACT
);
}
}
// Perform quantization
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
input
.
data
(),
output
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
}
}
// namespace
py
::
object
quantize
(
const
at
::
Tensor
&
tensor
,
py
::
handle
quantizer
,
const
py
::
object
&
output
,
...
...
@@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
const
auto
fake_dtype
=
input_cpp
.
dtype
();
std
::
tie
(
output_cpp
,
output_py
)
=
quantizer_cpp
->
create_tensor
(
shape
,
fake_dtype
);
}
else
{
output_py
=
output
;
output_cpp
=
makeTransformerEngineTensor
(
output_py
,
quantizer
);
std
::
tie
(
output_cpp
,
output_py
)
=
quantizer_cpp
->
convert_and_update_tensor
(
output
);
}
// Initialize no-op flag
TensorWrapper
noop_flag_cpp
;
std
::
optional
<
TensorWrapper
>
noop_flag_cpp
;
if
(
noop_flag
.
has_value
())
{
noop_flag_cpp
=
makeTransformerEngineTensor
(
*
noop_flag
);
}
// Perform quantization
quantize
_impl
(
input
_cpp
,
quantize
r
,
quantizer
_cpp
,
output_cpp
,
noop_flag_cpp
);
quantize
r
_cpp
->
quantize
(
input
_cpp
,
output_cpp
,
noop_flag_cpp
);
return
output_py
;
}
...
...
@@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
});
}
else
{
// Quantize kernels individually
TensorWrapper
dummy_noop_flag
;
for
(
size_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
quantize_impl
(
input_list
[
i
],
quantizer_py_list
[
i
],
quantizer_cpp_list
[
i
],
output_list
[
i
],
dummy_noop_flag
);
quantizer_cpp_list
[
i
]
->
quantize
(
input_list
[
i
],
output_list
[
i
]);
}
}
}
...
...
@@ -455,11 +398,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
}
// Allocate full buffer
// TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
auto
buffer
=
std
::
make_shared
<
at
::
Tensor
>
(
at
::
zeros
({(
int64_t
)
buffer_size
},
at
::
device
(
at
::
kCUDA
).
dtype
(
torch
::
kUInt8
)));
// auto buffer = std::make_shared<at::Tensor>(
// at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
at
::
empty
({(
int64_t
)
buffer_size
},
at
::
device
(
at
::
kCUDA
).
dtype
(
torch
::
kUInt8
)));
// Construct tensor views
for
(
size_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
...
...
@@ -498,11 +438,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
}
// Allocate full buffer
// TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
auto
buffer
=
std
::
make_shared
<
at
::
Tensor
>
(
at
::
zeros
({(
int64_t
)
buffer_size
},
at
::
device
(
at
::
kCUDA
).
dtype
(
torch
::
kUInt8
)));
// auto buffer = std::make_shared<at::Tensor>(
// at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
at
::
empty
({(
int64_t
)
buffer_size
},
at
::
device
(
at
::
kCUDA
).
dtype
(
torch
::
kUInt8
)));
// Construct tensor views
for
(
size_t
i
=
0
;
i
<
num_tensors
;
++
i
)
{
...
...
@@ -650,66 +587,5 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
return
output_py_list
;
}
// Common implementation for the fused dbias + activation-backward extensions.
//
// `func` is one of the nvte_quantize_dbias_d<act> C entry points. Given the
// incoming gradient `grad_output` and the forward-pass activation input
// `act_input`, it produces:
//   - dact: the activation backward result, quantized through `quantizer`;
//   - grad_bias: a torch tensor sized to grad_output's last dimension
//     (the bias-gradient reduction output).
//
// The NVTE kernel is invoked twice: the first call passes an empty workspace
// wrapper so the kernel only reports its required workspace shape/dtype; the
// workspace is then allocated and the second call performs the real launch.
// Both calls release the GIL around the CUDA work.
//
// Returns {grad_bias (cast to py::object), dact}.
template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
                       cudaStream_t)>
std::vector<py::object> dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input,
                                   py::handle quantizer) {
  init_extension();
  auto my_quantizer = convert_quantizer(quantizer);
  auto grad_tensor = makeTransformerEngineTensor(grad_output);
  // grad_bias holds the bias gradient; one element per column of grad_output.
  auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
  auto act_input_tensor = makeTransformerEngineTensor(act_input);
  const auto &shape = convertShape(grad_tensor.shape());
  // dact_tensor is the TE-side wrapper, dact the Python-side object.
  auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
  auto dbias_tensor = makeTransformerEngineTensor(grad_bias);

  // Query workspace size and allocate workspace
  transformer_engine::TensorWrapper workspace;
  NVTE_SCOPED_GIL_RELEASE({
    func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
         workspace.data(), at::cuda::getCurrentCUDAStream());
  });
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Launch kernel
  NVTE_SCOPED_GIL_RELEASE({
    func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
         workspace.data(), at::cuda::getCurrentCUDAStream());
  });

  return {py::cast(grad_bias), dact};
}
// Fused bias-gradient + GELU backward: forwards to dbias_dact with the
// nvte_quantize_dbias_dgelu kernel. Returns {grad_bias, dact}.
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer) {
  return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
}
// Fused bias-gradient + SiLU backward: forwards to dbias_dact with the
// nvte_quantize_dbias_dsilu kernel. Returns {grad_bias, dact}.
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer) {
  return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
}
// Fused bias-gradient + ReLU backward: forwards to dbias_dact with the
// nvte_quantize_dbias_drelu kernel. Returns {grad_bias, dact}.
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer) {
  return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
}
// Fused bias-gradient + QuickGELU backward: forwards to dbias_dact with the
// nvte_quantize_dbias_dqgelu kernel. Returns {grad_bias, dact}.
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer) {
  return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
}
// Fused bias-gradient + squared-ReLU backward: forwards to dbias_dact with the
// nvte_quantize_dbias_dsrelu kernel. Returns {grad_bias, dact}.
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer) {
  return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
}
}
// namespace pytorch
}
// namespace transformer_engine
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
View file @
87e3e56e
...
...
@@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
return
torch
::
from_blob
(
ubuf_ptr
,
*
shape
,
at
::
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
}
at
::
Stream
CommOverlap
::
get_communication_stream
()
{
return
at
::
cuda
::
getStreamFromExternal
(
_stream_comm
,
at
::
cuda
::
current_device
());
// Expose the communication streams as a (send, recv) pair. CommOverlap uses a
// single internal comm stream, so the same external stream is returned for
// both slots.
std::pair<at::Stream, at::Stream> CommOverlap::get_communication_stream() {
  const auto comm_stream =
      at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
  return {comm_stream, comm_stream};
}
/***************************************************************************************************
...
...
@@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
return
torch
::
from_blob
(
ubuf_ptr
,
*
shape
,
at
::
dtype
(
dtype
).
device
(
torch
::
kCUDA
));
}
at
::
Stream
CommOverlapP2P
::
get_communication_stream
()
{
return
at
::
cuda
::
getStreamFromExternal
(
_stream_recv
,
at
::
cuda
::
current_device
());
// Expose the communication streams as a (send, recv) pair. The P2P overlap
// keeps distinct send and recv streams; the first send stream and the recv
// stream are wrapped as external ATen streams on the current device.
std::pair<at::Stream, at::Stream> CommOverlapP2P::get_communication_stream() {
  const auto device = at::cuda::current_device();
  return {at::cuda::getStreamFromExternal(_stream_send[0], device),
          at::cuda::getStreamFromExternal(_stream_recv, device)};
}
// Drive an all-gather bulk overlap whose GEMM is executed externally: hand the
// caller-provided send/recv streams and the current compute stream to the
// communicator's bulk_overlap_external_ag.
void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm(
    CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) {
  // The current ATen CUDA stream is the compute ("main") stream the overlap
  // synchronizes against.
  const auto compute_stream = at::cuda::getCurrentCUDAStream();
  allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream),
                                                  at::cuda::CUDAStream(recv_stream),
                                                  compute_stream);
}
transformer_engine/pytorch/csrc/extensions/gemm.cpp
View file @
87e3e56e
...
...
@@ -94,7 +94,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
at
::
Tensor
workspace
,
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
CommOverlapCore
*
comm_overlap
,
std
::
optional
<
CommOverlapType
>
comm_type
,
MaybeTensor
extra_output
,
bool
bulk_overlap
)
{
bool
bulk_overlap
,
float
alpha
,
std
::
optional
<
float
>
beta
)
{
// Input tensors
NVTE_CHECK
(
!
A
.
is_none
(),
"Tensor A has not been provided"
);
NVTE_CHECK
(
!
B
.
is_none
(),
"Tensor B has not been provided"
);
...
...
@@ -112,6 +112,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
NVTE_CHECK
(
A_shape
.
ndim
>=
1
,
"Tensor A needs to have at least 1 dimension"
);
NVTE_CHECK
(
B_shape
.
ndim
>=
1
,
"Tensor B needs to have at least 1 dimension"
);
// Check scaling factors
if
(
accumulate
)
{
if
(
!
beta
)
{
beta
=
1.0
f
;
}
}
else
{
if
(
!
beta
)
{
beta
=
0.0
f
;
}
NVTE_CHECK
(
beta
==
0.0
,
"Trying to use non-zero beta while not accumulating "
,
"into D tensor. Beta has nothing to be applied to."
);
}
// Output tensor
TensorWrapper
D_tensor
;
if
(
D
.
is_none
())
{
...
...
@@ -240,9 +253,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
}
else
{
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE
({
nvte_cublas_gemm
(
A_tensor
.
data
(),
B_tensor
.
data
(),
D_tensor
.
data
(),
bias_tensor
.
data
(),
te_pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
te_workspace
.
data
(),
accumulate
,
use_split_accumulator
,
num_math_sms
,
main_stream
);
nvte_cublas_gemm_scaled
(
A_tensor
.
data
(),
B_tensor
.
data
(),
D_tensor
.
data
(),
bias_tensor
.
data
(),
te_pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
te_workspace
.
data
(),
alpha
,
*
beta
,
use_split_accumulator
,
num_math_sms
,
main_stream
);
});
}
}
else
{
...
...
@@ -328,10 +342,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
size_t
workspaceSize
,
bool
accumulate
,
bool
use_split_accumulator
,
int
math_sm_count
)
{
std
::
vector
<
NVTETensor
>
te_A_vector
,
te_B_vector
,
te_D_vector
,
te_bias_vector
,
te_pre_gelu_out_vector
,
te_workspace_vector
;
std
::
vector
<
TensorWrapper
>
wrappers
;
std
::
vector
<
TensorWrapper
>
te_A_wrappers
,
te_B_wrappers
,
wrappers
;
std
::
vector
<
at
::
Tensor
>
D_vectors
;
// Keep the swizzled scaling factor tensors alive during the GEMMs.
std
::
vector
<
std
::
optional
<
at
::
Tensor
>>
swizzled_scale_inverses_list
;
auto
none
=
py
::
none
();
...
...
@@ -398,10 +410,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
continue
;
}
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list
.
emplace_back
(
std
::
move
(
swizzle_scaling_factors
(
te_A
,
transa
)));
swizzled_scale_inverses_list
.
emplace_back
(
std
::
move
(
swizzle_scaling_factors
(
te_B
,
!
transb
)));
auto
te_D
=
makeTransformerEngineTensor
(
out_tensor
);
auto
te_bias
=
makeTransformerEngineTensor
(
bias
[
i
]);
auto
te_pre_gelu_out
=
makeTransformerEngineTensor
(
pre_gelu_out
[
i
]);
...
...
@@ -421,18 +429,25 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_bias_vector
.
emplace_back
(
te_bias
.
data
());
te_pre_gelu_out_vector
.
emplace_back
(
te_pre_gelu_out
.
data
());
wrappers
.
emplace_back
(
std
::
move
(
te_A
));
wrappers
.
emplace_back
(
std
::
move
(
te_B
));
te_A_
wrappers
.
emplace_back
(
std
::
move
(
te_A
));
te_B_
wrappers
.
emplace_back
(
std
::
move
(
te_B
));
wrappers
.
emplace_back
(
std
::
move
(
te_D
));
wrappers
.
emplace_back
(
std
::
move
(
te_bias
));
wrappers
.
emplace_back
(
std
::
move
(
te_pre_gelu_out
));
}
// Optionally swizzle the scaling factors
// Keep the swizzled scaling factor tensors alive during the GEMMs.
auto
swizzled_scale_inv_A
=
multi_tensor_swizzle_scaling_factors
(
te_A_wrappers
,
transa
);
auto
swizzled_scale_inv_B
=
multi_tensor_swizzle_scaling_factors
(
te_B_wrappers
,
!
transb
);
for
(
size_t
i
=
0
;
i
<
workspace
.
size
();
i
++
)
{
auto
wsp
=
makeTransformerEngineTensor
(
workspace
[
i
].
data_ptr
(),
std
::
vector
<
size_t
>
{
workspaceSize
},
DType
::
kByte
);
te_workspace_vector
.
emplace_back
(
wsp
.
data
());
wrappers
.
emplace_back
(
std
::
move
(
wsp
));
}
// For now, we only have multi-stream cublas backend.
const
char
*
NVTE_USE_HIPBLASLT_GROUPEDGEMM
=
std
::
getenv
(
"NVTE_USE_HIPBLASLT_GROUPEDGEMM"
);
if
(
NVTE_USE_HIPBLASLT_GROUPEDGEMM
!=
nullptr
&&
NVTE_USE_HIPBLASLT_GROUPEDGEMM
[
0
]
==
'1'
){
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp
View file @
87e3e56e
...
...
@@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
auto
noop_flag_cu
=
makeTransformerEngineTensor
(
noop_flag
);
auto
[
_
,
__
,
tensor_lists_ptr
,
num_lists
,
num_tensors
]
=
makeTransformerEngineTensorList
(
tensor_lists
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_adam_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
lr
,
beta1
,
beta2
,
epsilon
,
step
,
mode
,
bias_correction
,
weight_decay
,
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
weight_decay
,
at
::
cuda
::
getCurrentCUDAStream
());
}
void
multi_tensor_adam_param_remainder_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
...
...
@@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
auto
noop_flag_cu
=
makeTransformerEngineTensor
(
noop_flag
);
auto
[
_
,
__
,
tensor_lists_ptr
,
num_lists
,
num_tensors
]
=
makeTransformerEngineTensorList
(
tensor_lists
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_adam_param_remainder_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
lr
,
beta1
,
beta2
,
epsilon
,
step
,
mode
,
bias_correction
,
weight_decay
,
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
beta2
,
epsilon
,
step
,
mode
,
bias_correction
,
weight_decay
,
at
::
cuda
::
getCurrentCUDAStream
());
}
void
multi_tensor_adam_fp8_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
...
...
@@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
auto
noop_flag_cu
=
makeTransformerEngineTensor
(
noop_flag
);
auto
[
_
,
__
,
tensor_lists_ptr
,
num_lists
,
num_tensors
]
=
makeTransformerEngineTensorList
(
tensor_lists
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_adam_fp8_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
lr
,
beta1
,
beta2
,
epsilon
,
step
,
mode
,
bias_correction
,
weight_decay
,
static_cast
<
NVTEDType
>
(
fp8_dtype
),
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
at
::
cuda
::
getCurrentCUDAStream
());
}
void
multi_tensor_adam_capturable_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
...
...
@@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
auto
lr_cu
=
makeTransformerEngineTensor
(
lr
);
auto
step_cu
=
makeTransformerEngineTensor
(
step
);
auto
inv_scale_cu
=
makeTransformerEngineTensor
(
inv_scale
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_adam_capturable_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
lr_cu
.
data
(),
beta1
,
beta2
,
epsilon
,
step_cu
.
data
(),
mode
,
bias_correction
,
weight_decay
,
inv_scale_cu
.
data
(),
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
inv_scale_cu
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
}
void
multi_tensor_adam_capturable_master_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
...
...
@@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
auto
lr_cu
=
makeTransformerEngineTensor
(
lr
);
auto
step_cu
=
makeTransformerEngineTensor
(
step
);
auto
inv_scale_cu
=
makeTransformerEngineTensor
(
inv_scale
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_adam_capturable_master_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
lr_cu
.
data
(),
beta1
,
beta2
,
epsilon
,
step_cu
.
data
(),
mode
,
bias_correction
,
weight_decay
,
inv_scale_cu
.
data
(),
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
inv_scale_cu
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
}
}
// namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
View file @
87e3e56e
...
...
@@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
auto
noop_flag_cu
=
makeTransformerEngineTensor
(
noop_flag
);
auto
[
_
,
__
,
tensor_lists_ptr
,
num_lists
,
num_tensors
]
=
makeTransformerEngineTensorList
(
tensor_lists
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
max_fp8
,
force_pow_2_scales
,
epsilon
,
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
force_pow_2_scales
,
epsilon
,
at
::
cuda
::
getCurrentCUDAStream
());
}
}
// namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp
View file @
87e3e56e
...
...
@@ -43,12 +43,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
auto
output_per_tensor_cu
=
makeTransformerEngineTensor
(
output_per_tensor
);
auto
ret_cu
=
makeTransformerEngineTensor
(
ret
);
auto
ret_per_tensor_cu
=
makeTransformerEngineTensor
(
ret_per_tensor
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_l2norm_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
output_cu
.
data
(),
output_per_tensor_cu
.
data
(),
ret_cu
.
data
(),
ret_per_tensor_cu
.
data
(),
per_tensor
,
max_chunks_per_tensor
,
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
max_chunks_per_tensor
,
at
::
cuda
::
getCurrentCUDAStream
());
return
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
(
ret
,
ret_per_tensor
);
}
...
...
@@ -91,13 +90,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
auto
ret_cu
=
makeTransformerEngineTensor
(
ret
);
auto
ret_per_tensor_cu
=
makeTransformerEngineTensor
(
ret_per_tensor
);
auto
inv_scale_cu
=
makeTransformerEngineTensor
(
inv_scale
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_unscale_l2norm_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
output_cu
.
data
(),
output_per_tensor_cu
.
data
(),
ret_cu
.
data
(),
ret_per_tensor_cu
.
data
(),
inv_scale_cu
.
data
(),
per_tensor
,
max_chunks_per_tensor
,
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
inv_scale_cu
.
data
(),
per_tensor
,
max_chunks_per_tensor
,
at
::
cuda
::
getCurrentCUDAStream
());
return
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
(
ret
,
ret_per_tensor
);
}
...
...
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp
View file @
87e3e56e
...
...
@@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
auto
noop_flag_cu
=
makeTransformerEngineTensor
(
noop_flag
);
auto
[
_
,
__
,
tensor_lists_ptr
,
num_lists
,
num_tensors
]
=
makeTransformerEngineTensorList
(
tensor_lists
);
int
device_id
=
tensor_lists
[
0
][
0
].
device
().
index
();
nvte_multi_tensor_scale_cuda
(
chunk_size
,
noop_flag_cu
.
data
(),
tensor_lists_ptr
.
data
(),
num_lists
,
num_tensors
,
scale
,
device_id
,
at
::
cuda
::
getCurrentCUDAStream
());
num_tensors
,
scale
,
at
::
cuda
::
getCurrentCUDAStream
());
}
}
// namespace transformer_engine::pytorch
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment