Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
2b05e121
Commit
2b05e121
authored
Jun 17, 2025
by
yuguo
Browse files
Merge commit '
a69692ac
' of...
Merge commit '
a69692ac
' of
https://github.com/NVIDIA/TransformerEngine
parents
0fd441c2
a69692ac
Changes
245
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1769 additions
and
432 deletions
+1769
-432
transformer_engine/jax/dense.py
transformer_engine/jax/dense.py
+203
-98
transformer_engine/jax/flax/transformer.py
transformer_engine/jax/flax/transformer.py
+12
-3
transformer_engine/jax/pyproject.toml
transformer_engine/jax/pyproject.toml
+10
-0
transformer_engine/jax/quantize/dequantizer.py
transformer_engine/jax/quantize/dequantizer.py
+157
-40
transformer_engine/jax/quantize/quantizer.py
transformer_engine/jax/quantize/quantizer.py
+237
-49
transformer_engine/jax/quantize/scaling_modes.py
transformer_engine/jax/quantize/scaling_modes.py
+173
-2
transformer_engine/jax/quantize/tensor.py
transformer_engine/jax/quantize/tensor.py
+211
-23
transformer_engine/jax/setup.py
transformer_engine/jax/setup.py
+4
-16
transformer_engine/jax/sharding.py
transformer_engine/jax/sharding.py
+26
-0
transformer_engine/pytorch/attention/dot_product_attention/backends.py
...ngine/pytorch/attention/dot_product_attention/backends.py
+6
-1
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
...torch/attention/dot_product_attention/context_parallel.py
+509
-154
transformer_engine/pytorch/attention/dot_product_attention/utils.py
...r_engine/pytorch/attention/dot_product_attention/utils.py
+34
-25
transformer_engine/pytorch/attention/multi_head_attention.py
transformer_engine/pytorch/attention/multi_head_attention.py
+60
-0
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+8
-0
transformer_engine/pytorch/cpu_offload.py
transformer_engine/pytorch/cpu_offload.py
+83
-7
transformer_engine/pytorch/csrc/common.cpp
transformer_engine/pytorch/csrc/common.cpp
+14
-0
transformer_engine/pytorch/csrc/common.h
transformer_engine/pytorch/csrc/common.h
+11
-5
transformer_engine/pytorch/csrc/extensions.h
transformer_engine/pytorch/csrc/extensions.h
+9
-7
transformer_engine/pytorch/csrc/extensions/activation.cpp
transformer_engine/pytorch/csrc/extensions/activation.cpp
+1
-1
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp
+1
-1
No files found.
transformer_engine/jax/dense.py
View file @
2b05e121
...
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
...
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
# GEMM NT
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_con
s
tracting_dim
=
tuple
(
g_contracting_dim
=
tuple
(
range
(
grad
.
ndim
-
len
(
kernel_shape
)
+
len
(
fwd_k_contracting_dims
),
grad
.
ndim
)
range
(
grad
.
ndim
-
len
(
kernel_shape
)
+
len
(
fwd_k_contracting_dims
),
grad
.
ndim
)
)
)
# k_non_contracting_dims
# k_non_contracting_dims
k_con
s
tracting_dim
=
tuple
(
k_contracting_dim
=
tuple
(
dim
for
dim
in
range
(
len
(
kernel_shape
))
if
dim
not
in
fwd_k_contracting_dims
dim
for
dim
in
range
(
len
(
kernel_shape
))
if
dim
not
in
fwd_k_contracting_dims
)
)
dgrad
=
tex
.
gemm
(
dgrad
=
tex
.
gemm
(
casted_grad
.
get_rowwise_tensor
(),
casted_grad
.
get_rowwise_tensor
(),
rowwise_casted_kernel
,
rowwise_casted_kernel
,
(
g_con
s
tracting_dim
,
k_con
s
tracting_dim
),
(
g_contracting_dim
,
k_contracting_dim
),
)
)
dgrad
=
with_sharding_constraint_by_logical_axes
(
dgrad
,
input_axes
)
dgrad
=
with_sharding_constraint_by_logical_axes
(
dgrad
,
input_axes
)
# GEMM TN
# GEMM TN
# x_non_contracting_dims
# x_non_contracting_dims
g_con
s
tracting_dim
=
x_con
s
tracting_dim
=
tuple
(
g_contracting_dim
=
x_contracting_dim
=
tuple
(
range
(
0
,
len
(
x_shape
)
-
len
(
fwd_x_contracting_dims
))
range
(
0
,
len
(
x_shape
)
-
len
(
fwd_x_contracting_dims
))
)
)
wgrad
=
tex
.
gemm
(
wgrad
=
tex
.
gemm
(
colwise_casted_x
,
casted_grad
.
get_colwise_tensor
(),
(
x_con
s
tracting_dim
,
g_con
s
tracting_dim
)
colwise_casted_x
,
casted_grad
.
get_colwise_tensor
(),
(
x_contracting_dim
,
g_contracting_dim
)
)
)
wgrad
=
with_sharding_constraint_by_logical_axes
(
wgrad
,
kernel_axes
)
wgrad
=
with_sharding_constraint_by_logical_axes
(
wgrad
,
kernel_axes
)
...
@@ -184,135 +184,240 @@ def _dense_bwd_rule(
...
@@ -184,135 +184,240 @@ def _dense_bwd_rule(
_dense
.
defvjp
(
_dense_fwd_rule
,
_dense_bwd_rule
)
_dense
.
defvjp
(
_dense_fwd_rule
,
_dense_bwd_rule
)
"""
def
grouped_dense
(
def
grouped_dense
(
x_list,
x
:
jnp
.
ndarray
,
kernel_list,
kernel
:
jnp
.
ndarray
,
bias_list,
group_sizes
:
jnp
.
ndarray
,
contracting_dims_list,
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]]
=
((
1
,),
(
1
,)),
quantizer_set_list=None,
bias
:
jnp
.
ndarray
=
None
,
precision
:
jax
.
lax
.
Precision
=
jax
.
lax
.
Precision
.
DEFAULT
,
preferred_element_type
:
jnp
.
dtype
=
None
,
group_offset
:
jnp
.
array
=
None
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
):
):
# Perform grouped_dense layer transformation with optional quantization.
"""
Perform grouped dense (linear) layer transformation with optional quantization.
output_list = _grouped_dense(
Args:
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
x: Input tensor of shape (M, K)
kernel: Weight matrix of shape (G, K, N)
group_sizes: 1D array of shape (G,) specifying the size of each group
contracting_dims: Tuple of sequences specifying which dimensions to contract
(currently only supports ((1,), (1,)))
bias: Bias tensor of shape (G, N)
precision: JAX precision for the GEMM operation
preferred_element_type: Preferred data type for the output tensor
group_offset: 1D array containing offsets for each group (not yet implemented)
quantizer_set: Set of quantizers for FP8 quantization of the input and output
Returns:
A jnp.ndarray containing the result of the grouped linear operation
"""
output
=
_grouped_dense
(
x
,
kernel
,
group_sizes
,
contracting_dims
,
bias
,
precision
,
preferred_element_type
,
group_offset
,
quantizer_set
,
)
)
return output
_list
return
output
@partial(jax.custom_vjp, nondiff_argnums=(3,))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
3
,
5
,
6
,
7
))
def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
def
_grouped_dense
(
output_list, _ = _grouped_dense_fwd_rule(
x
,
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
kernel
,
group_sizes
,
contracting_dims
,
bias
,
precision
,
preferred_element_type
,
group_offset
,
quantizer_set
,
):
output
,
_
=
_grouped_dense_fwd_rule
(
x
,
kernel
,
group_sizes
,
contracting_dims
,
bias
,
precision
,
preferred_element_type
,
group_offset
,
quantizer_set
,
)
)
return output
_list
return
output
def
_grouped_dense_fwd_rule
(
def
_grouped_dense_fwd_rule
(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
x
,
kernel
,
group_sizes
,
contracting_dims
,
bias
,
precision
,
preferred_element_type
,
group_offset
,
quantizer_set
,
):
):
use_bias = bias_list is not None
use_bias
=
bias
is
not
None
output_list = []
is_noop_quantizer_set
=
quantizer_set
==
noop_quantizer_set
x_rowwise_list = []
x_colwise_list = []
if
is_noop_quantizer_set
:
kernel_colwise_list = []
grouped_gemm_x
=
x
kernel_rowwise_list = []
grouped_gemm_kernel
=
kernel
x_shape_list = []
ctx_x
=
x
kernel_shape_list = []
ctx_kernel
=
kernel
if quantizer_set_list is None:
flatten_axis_k
=
None
x_rowwise_list = x_list
x_colwise_list = x_list
kernel_colwise_list = kernel_list
kernel_rowwise_list = kernel_list
x_shape_list = [x.shape for x in x_list]
kernel_shape_list = [kernel.shape for kernel in kernel_list]
else
:
else
:
for i in range(len(x_list)): # pylint: disable=consider-using-enumerate
x_contracting_dims
,
k_contracting_dims
=
contracting_dims
q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
flatten_axis_x
=
-
len
(
x_contracting_dims
)
q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
flatten_axis_k
=
len
(
k_contracting_dims
)
-
len
(
kernel
.
shape
)
+
1
# +1 for G axis
x_rowwise_list.append(q_x.get_rowwise_tensor())
x_colwise_list.append(q_x.get_colwise_tensor())
assert
x
.
ndim
==
2
,
"Grouped dense expects a 2D input tensor of shape (M, K)"
kernel_colwise_list.append(q_kernel.get_colwise_tensor())
assert
kernel
.
ndim
==
3
,
"Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
# Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
x_shape_list.append(x_rowwise_list[-1].data.shape)
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
assert
x_contracting_dims
==
(
1
,)
and
k_contracting_dims
==
(
1
,),
(
"grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
output_list = tex.grouped_gemm(
"and k_contracting_dims=(1,) for now, "
x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
f
"got
{
x_contracting_dims
=
}
and
{
k_contracting_dims
=
}
"
)
k_contracting_dims
=
(
0
,)
casted_x
=
tex
.
grouped_quantize
(
x
,
quantizer_set
.
x
,
group_sizes
,
flatten_axis
=
flatten_axis_x
)
casted_kernel
=
tex
.
grouped_quantize
(
kernel
,
quantizer_set
.
kernel
,
flatten_axis
=
flatten_axis_k
)
contracting_dims
=
(
x_contracting_dims
,
k_contracting_dims
)
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
# rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x
=
casted_x
.
get_rowwise_tensor
()
grouped_gemm_kernel
=
casted_kernel
.
get_colwise_tensor
()
# TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
ctx_x
=
casted_x
.
get_colwise_tensor
()
if
quantizer_set
.
x
.
is_2x2x
()
else
None
ctx_kernel
=
casted_kernel
.
get_rowwise_tensor
()
if
quantizer_set
.
kernel
.
is_2x2x
()
else
None
output
=
tex
.
grouped_gemm
(
grouped_gemm_x
,
grouped_gemm_kernel
,
group_sizes
,
contracting_dims
,
bias
,
precision
,
preferred_element_type
,
group_offset
,
)
)
ctx
=
(
ctx
=
(
x_colwise_list,
group_sizes
,
kernel_rowwise_list,
ctx_x
,
x_shape_list,
ctx_kernel
,
kernel_shape_list,
x
.
shape
,
kernel
.
shape
,
use_bias
,
use_bias
,
quantizer_set_list,
is_noop_quantizer_set
,
quantizer_set
,
flatten_axis_k
,
)
)
return output_list, ctx
return
output
,
ctx
def
_grouped_dense_bwd_rule
(
contracting_dims
,
precision
,
preferred_element_type
,
group_offset
,
ctx
,
grad
):
fwd_x_contracting_dims
,
fwd_k_contracting_dims
=
contracting_dims
def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
(
(
colwise_x_list,
group_sizes
,
rowwise_kernel_list,
ctx_x
,
x_shape_list,
ctx_kernel
,
kernel_shape_list,
x_shape
,
kernel_shape
,
use_bias
,
use_bias
,
quantizer_set_list,
is_noop_quantizer_set
,
quantizer_set
,
flatten_axis_k
,
)
=
ctx
)
=
ctx
group_size = len(grad_list)
if
is_noop_quantizer_set
:
dbias_list = []
# The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
grad_rowwise_list = []
# g_contracting_dim = (1, )
grad_colwise_list = []
# k_contracting_dim = (2, )
dgrad_contracting_dims_list = []
wgrad_contracting_dims_list = []
for i in range(group_size):
grad = grad_list[i]
x_shape = x_shape_list[i]
kernel_shape = kernel_shape_list[i]
fwd_contracting_dims = contracting_dims_list[i]
if quantizer_set_list is None:
casted_grad = grad
dbias = tex.quantization._jax_dbias(grad)
grad_rowwise_list.append(grad)
grad_colwise_list.append(grad)
else:
quantizer_set = quantizer_set_list[i]
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
)
grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
grad_colwise_list.append(casted_grad.get_colwise_tensor())
dbias_list.append(dbias)
# GEMM NT
fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
g_contracting_dim
=
tuple
(
g_contracting_dim
=
tuple
(
range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
range
(
1
+
grad
.
ndim
-
len
(
kernel_shape
)
+
len
(
fwd_k_contracting_dims
),
grad
.
ndim
)
)
)
k_contracting_dim
=
tuple
(
k_contracting_dim
=
tuple
(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
dim
for
dim
in
range
(
1
,
len
(
kernel_shape
))
if
dim
not
in
fwd_k_contracting_dims
)
)
dgrad_contracting_dims
=
(
g_contracting_dim
,
k_contracting_dim
)
dgrad_contracting_dims
=
(
g_contracting_dim
,
k_contracting_dim
)
dgrad_contracting_dims_list.append(dgrad_contracting_dims)
dgrad_grad
=
grad
dgrad_kernel_T
=
ctx_kernel
# GEMM TN
# g_contracting_dim = (0, )
# x_contracting_dim = (0, )
g_contracting_dim
=
x_contracting_dim
=
tuple
(
g_contracting_dim
=
x_contracting_dim
=
tuple
(
range
(
0
,
len
(
x_shape
)
-
len
(
fwd_x_contracting_dims
))
range
(
0
,
len
(
x_shape
)
-
len
(
fwd_x_contracting_dims
))
)
)
wgrad_contracting_dims
=
(
x_contracting_dim
,
g_contracting_dim
)
wgrad_contracting_dims
=
(
x_contracting_dim
,
g_contracting_dim
)
wgrad_contracting_dims_list.append(wgrad_contracting_dims)
wgrad_x_T
=
ctx_x
wgrad_grad
=
grad
else
:
casted_grad
=
tex
.
grouped_quantize
(
grad
,
quantizer_set
.
dgrad
,
group_sizes
,
flatten_axis
=
flatten_axis_k
)
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
# g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
# extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim
=
(
1
,)
k_contracting_dim
=
(
2
,)
dgrad_contracting_dims
=
(
g_contracting_dim
,
k_contracting_dim
)
dgrad_grad
=
casted_grad
.
get_rowwise_tensor
()
dgrad_kernel_T
=
ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
# after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim
=
(
0
,)
x_contracting_dim
=
(
0
,)
wgrad_contracting_dims
=
(
x_contracting_dim
,
g_contracting_dim
)
wgrad_x_T
=
ctx_x
wgrad_grad
=
casted_grad
.
get_colwise_tensor
()
dgrad
=
tex
.
grouped_gemm
(
dgrad_grad
,
dgrad_kernel_T
,
group_sizes
,
dgrad_contracting_dims
,
precision
=
precision
,
preferred_element_type
=
preferred_element_type
,
group_offset
=
group_offset
,
)
dgrad_list = tex.grouped_gemm(
wgrad
=
tex
.
grouped_gemm
(
grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
wgrad_x_T
,
wgrad_grad
,
group_sizes
,
wgrad_contracting_dims
,
precision
=
precision
,
preferred_element_type
=
preferred_element_type
,
group_offset
=
group_offset
,
)
)
wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
group_sizes_grad
=
None
dbias
=
tex
.
grouped_dbias
(
grad
,
group_sizes
)
if
use_bias
else
None
return
dgrad
,
wgrad
,
group_sizes_grad
,
dbias
,
quantizer_set
_grouped_dense
.
defvjp
(
_grouped_dense_fwd_rule
,
_grouped_dense_bwd_rule
)
_grouped_dense
.
defvjp
(
_grouped_dense_fwd_rule
,
_grouped_dense_bwd_rule
)
"""
transformer_engine/jax/flax/transformer.py
View file @
2b05e121
...
@@ -594,8 +594,16 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -594,8 +594,16 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
seqlen_kv
=
seqlen_q
seqlen_kv
=
seqlen_q
else
:
else
:
seqlen_kv
=
key
.
shape
[
sequence_dim
]
seqlen_kv
=
key
.
shape
[
sequence_dim
]
if
qkv_layout
.
is_separate
():
head_dim_qk
=
query
.
shape
[
-
1
]
head_dim_v
=
value
.
shape
[
-
1
]
else
:
head_dim_qk
=
self
.
head_dim
head_dim_v
=
self
.
head_dim
has_fused_attn_kernel
=
is_fused_attn_kernel_available
(
has_fused_attn_kernel
=
is_fused_attn_kernel_available
(
# This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
not
deterministic
,
self
.
dtype
,
self
.
dtype
,
self
.
dtype
,
self
.
dtype
,
qkv_layout
,
qkv_layout
,
...
@@ -606,7 +614,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -606,7 +614,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
self
.
num_gqa_groups
,
self
.
num_gqa_groups
,
seqlen_q
,
seqlen_q
,
seqlen_kv
,
seqlen_kv
,
self
.
head_dim
,
head_dim_qk
,
head_dim_v
,
self
.
window_size
,
self
.
window_size
,
)
)
...
@@ -619,7 +628,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -619,7 +628,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
"Please try to update the cuDNN and TE to the latest version.
\n
"
"Please try to update the cuDNN and TE to the latest version.
\n
"
f
"
{
self
.
dtype
=
}
\n
{
qkv_layout
=
}
\n
{
attn_bias_type
=
}
\n
{
attn_mask_type
=
}
\n
"
f
"
{
self
.
dtype
=
}
\n
{
qkv_layout
=
}
\n
{
attn_bias_type
=
}
\n
{
attn_mask_type
=
}
\n
"
f
"
{
self
.
attention_dropout
=
}
\n
{
self
.
num_attention_heads
=
}
\n
"
f
"
{
self
.
attention_dropout
=
}
\n
{
self
.
num_attention_heads
=
}
\n
"
f
"
{
self
.
num_gqa_groups
=
}
\n
{
seqlen_q
=
}
\n
{
seqlen_kv
=
}
\n
{
self
.
head_dim
=
}
\n
"
f
"
{
self
.
num_gqa_groups
=
}
\n
{
seqlen_q
=
}
\n
{
seqlen_kv
=
}
\n
{
head_dim_qk
=
}
\n
{
head_dim
_v
=
}
\n
"
)
)
dropout_rng
=
None
dropout_rng
=
None
...
@@ -627,7 +636,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -627,7 +636,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
dropout_rng
=
self
.
make_rng
(
self
.
dropout_rng_name
)
dropout_rng
=
self
.
make_rng
(
self
.
dropout_rng_name
)
if
self
.
scale_factor
is
None
:
if
self
.
scale_factor
is
None
:
scale_factor
=
1.0
/
sqrt
(
self
.
head_dim
)
scale_factor
=
1.0
/
sqrt
(
head_dim
_qk
)
else
:
else
:
scale_factor
=
self
.
scale_factor
scale_factor
=
self
.
scale_factor
del
self
.
scale_factor
del
self
.
scale_factor
...
...
transformer_engine/jax/pyproject.toml
0 → 100755
View file @
2b05e121
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires
=
[
"setuptools>=61.0"
,
"pybind11[global]"
,
"pip"
,
"jax[cuda12]"
,
"flax>=0.7.1"
]
# Use legacy backend to import local packages in setup.py
build-backend
=
"setuptools.build_meta:__legacy__"
transformer_engine/jax/quantize/dequantizer.py
View file @
2b05e121
...
@@ -7,24 +7,54 @@ Dequantization utilities for TE/JAX.
...
@@ -7,24 +7,54 @@ Dequantization utilities for TE/JAX.
This module provides utilities for dequantizing tensors that have been quantized
This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
using various scaling modes, including delayed scaling and block scaling.
"""
"""
import
math
from
dataclasses
import
dataclass
from
abc
import
ABC
,
abstractmethod
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
.scaling_modes
import
ScalingMode
from
.scaling_modes
import
ScalingMode
__all__
=
[
"Dequantizer"
]
__all__
=
[
"ScalingModeToDequantizerMap"
]
@
dataclass
class
Dequantizer
(
ABC
):
"""
Base Dequantizer Class
"""
@
staticmethod
@
abstractmethod
def
_dequantize_func
(
data
,
scale_inv
,
dq_dtype
,
**
kwargs
):
pass
@
staticmethod
@
abstractmethod
def
dequantize
(
scaled_tensor
):
"""Dequantizing given tensor to higher precision."""
class
Dequantizer
:
class
TensorScaleDequantizer
(
Dequantizer
):
"""Encapsulation class for dequantization helpers.
"""
TensorScaling Dequantizer Class
This class provides static methods for dequantizing tensors that have been
This class provides static methods for dequantizing tensors that have been
quantized using different scaling modes. It supports both delayed scaling
quantized using different
tensor
scaling modes. It supports both delayed scaling
and
block
scaling modes.
and
current
scaling modes.
"""
"""
@
staticmethod
@
staticmethod
def
_dq_func_tensor_scaling
(
scaled_tensor
):
def
_dequantize_func
(
data
,
scale_inv
,
dq_dtype
,
**
kwargs
):
del
kwargs
return
jnp
.
asarray
(
data
.
astype
(
jnp
.
float32
)
*
scale_inv
.
astype
(
jnp
.
float32
),
dq_dtype
,
)
@
staticmethod
def
dequantize
(
scaled_tensor
):
"""Dequantize a tensor using delayed scaling.
"""Dequantize a tensor using delayed scaling.
This function dequantizes a tensor that was quantized using delayed scaling
This function dequantizes a tensor that was quantized using delayed scaling
...
@@ -36,36 +66,48 @@ class Dequantizer:
...
@@ -36,36 +66,48 @@ class Dequantizer:
Returns:
Returns:
The dequantized tensor in the specified data type
The dequantized tensor in the specified data type
"""
"""
return
jnp
.
asarray
(
return
TensorScaleDequantizer
.
_dequantize_func
(
scaled_tensor
.
data
.
astype
(
jnp
.
float32
)
*
scaled_tensor
.
scale_inv
.
astype
(
jnp
.
float32
),
scaled_tensor
.
data
,
scaled_tensor
.
scale_inv
,
scaled_tensor
.
dq_dtype
scaled_tensor
.
dq_dtype
,
)
)
class
BlockScaleDequantizer
(
Dequantizer
):
"""BlockScaling Dequantizer Class.
This class provides static methods for dequantizing tensors that have been
quantized using block scaling modes.
"""
@
staticmethod
@
staticmethod
def
_d
q_func_block_scaling
(
scaled_tensor
):
def
_d
equantize_func
(
data
,
scale_inv
,
dq_dtype
,
scaling_mode
,
is_colwise
,
flatten_axis
):
"""Dequantize a tensor using block scaling.
"""Dequantize a tensor using block scaling.
This function dequantizes a tensor that was quantized using block scaling
by applying the inverse scaling factor to each block of data.
Args:
Args:
scaled_tensor: The quantized tensor to dequantize
data: The quantized tensor data
scale_inv: The inverse scaling factors
dq_dtype: The data type for dequantized values
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
Returns:
Returns:
The dequantized tensor
in the specified data type
The dequantized tensor
"""
"""
data
=
scaled_tensor
.
data
.
astype
(
jnp
.
float32
)
data
=
data
.
astype
(
jnp
.
float32
)
scale_inv
=
scale_inv
.
view
(
jnp
.
uint8
).
astype
(
jnp
.
float32
)
data_shape
=
data
.
shape
data_shape
=
data
.
shape
scale
=
scaled_tensor
.
scale_inv
.
view
(
jnp
.
uint8
).
astype
(
jnp
.
float32
)
flatten_axis
=
scaled_tensor
.
flatten_axis
flatten_axis
=
len
(
data_shape
)
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
flatten_axis
=
len
(
data_shape
)
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
assert
(
assert
(
0
<
flatten_axis
<
len
(
data_shape
)
0
<
flatten_axis
<
len
(
data_shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
data_shape
}
"
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
data_shape
}
"
scale_shape
=
scaled_tensor
.
scaling_mode
.
get_scale_shape
(
scale_shape
=
scaling_mode
.
get_scale_shape
(
data_shape
,
scaled_tensor
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
data_shape
,
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
)
scale
=
jax
.
lax
.
slice
(
scale
,
[
0
]
*
len
(
scale_shape
),
scale_shape
)
# slice out the padding
scale_inv
=
jax
.
lax
.
slice
(
scale_inv
,
[
0
]
*
len
(
scale_shape
),
scale_shape
)
# slice out the padding
data
=
data
.
reshape
(
data
=
data
.
reshape
(
*
data_shape
[:
flatten_axis
-
1
],
*
data_shape
[:
flatten_axis
-
1
],
...
@@ -76,31 +118,106 @@ class Dequantizer:
...
@@ -76,31 +118,106 @@ class Dequantizer:
int
(
data_shape
[
-
1
]
/
scale_shape
[
-
1
]),
int
(
data_shape
[
-
1
]
/
scale_shape
[
-
1
]),
)
)
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale_inv
=
jnp
.
expand_dims
(
scale_inv
,
axis
=
(
flatten_axis
+
2
-
2
,
-
1
))
scale
=
jnp
.
expand_dims
(
scale
,
axis
=
(
flatten_axis
+
2
-
2
,
-
1
))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return
jnp
.
asarray
(
data
*
jnp
.
power
(
2
,
scale
-
127
),
scaled_tensor
.
dq_dtype
).
reshape
(
data_shape
)
funcs
=
{
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
ScalingMode
.
DELAYED_TENSOR_SCALING
:
_dq_func_tensor_scaling
,
return
jnp
.
asarray
(
data
*
jnp
.
power
(
2
,
scale_inv
-
127
),
dq_dtype
).
reshape
(
data_shape
)
ScalingMode
.
CURRENT_TENSOR_SCALING
:
_dq_func_tensor_scaling
,
ScalingMode
.
MXFP8_1D_SCALING
:
_dq_func_block_scaling
,
}
@
staticmethod
@
staticmethod
def
dequantize
(
scaled_tensor
):
def
dequantize
(
scaled_tensor
):
"""Dequantize a scaled tensor using the appropriate scaling mode.
"""Dequantize a tensor using block scaling.
Args:
data: The quantized tensor data
scale_inv: The inverse scaling factors
dq_dtype: The data type for dequantized values
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
Returns:
The dequantized tensor
"""
return
BlockScaleDequantizer
.
_dequantize_func
(
scaled_tensor
.
data
,
scaled_tensor
.
scale_inv
,
scaled_tensor
.
dq_dtype
,
scaled_tensor
.
scaling_mode
,
scaled_tensor
.
is_colwise
,
scaled_tensor
.
flatten_axis
,
)
ScalingModeToDequantizerMap
=
{
ScalingMode
.
DELAYED_TENSOR_SCALING
:
TensorScaleDequantizer
,
ScalingMode
.
CURRENT_TENSOR_SCALING
:
TensorScaleDequantizer
,
ScalingMode
.
MXFP8_1D_SCALING
:
BlockScaleDequantizer
,
}
This method selects the appropriate dequantization function based on the
scaling mode used for quantization and applies it to the tensor.
@
staticmethod
def
_grouped_dequantize
(
grouped_scaled_tensor
):
"""Dequantize a grouped tensor.
Args:
Args:
scaled_tensor: The
quantiz
ed tensor to dequantize
grouped_
scaled_tensor: The
grouped scal
ed tensor to dequantize
Returns:
Returns:
The
dequantized tensor
in the specified data type
List of
dequantized tensor
s for each group
"""
"""
dq_func
=
Dequantizer
.
funcs
[
scaled_tensor
.
scaling_mode
]
data
=
grouped_scaled_tensor
.
data
return
dq_func
(
scaled_tensor
)
scale_inv
=
grouped_scaled_tensor
.
scale_inv
group_sizes
=
grouped_scaled_tensor
.
group_sizes
flatten_axis
=
grouped_scaled_tensor
.
flatten_axis
scaling_mode
=
grouped_scaled_tensor
.
scaling_mode
original_shape
=
grouped_scaled_tensor
.
original_shape
group_axis
=
grouped_scaled_tensor
.
group_axis
flatten_axis
=
len
(
original_shape
)
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
output
=
[]
non_group_shape
=
tuple
(
original_shape
[
i
]
for
i
in
range
(
len
(
original_shape
))
if
i
!=
group_axis
)
matrix_sizes
=
group_sizes
*
math
.
prod
(
non_group_shape
)
data
=
jnp
.
split
(
data
,
jnp
.
cumulative_sum
(
matrix_sizes
)[:
-
1
])
scale_inv_ptr
=
0
for
i
,
data_i
in
enumerate
(
data
):
data_shape_i
=
(
*
original_shape
[:
group_axis
],
group_sizes
[
i
],
*
original_shape
[
group_axis
+
1
:],
)
assert
math
.
prod
(
data_shape_i
)
==
data_i
.
size
,
(
f
"math.prod(
{
data_shape_i
}
) =
{
math
.
prod
(
data_shape_i
)
}
which is not equal to"
f
"
{
data_i
.
size
}
"
)
scale_shape_i
=
scaling_mode
.
get_scale_shape
(
data_shape_i
,
grouped_scaled_tensor
.
is_colwise
,
is_padded
=
True
,
flatten_axis
=
flatten_axis
,
)
scale_shape_i_size
=
math
.
prod
(
scale_shape_i
)
scale_inv_i
=
scale_inv
[
scale_inv_ptr
:
scale_inv_ptr
+
scale_shape_i_size
]
dequantizer_type
=
ScalingModeToDequantizerMap
.
get
(
grouped_scaled_tensor
.
scaling_mode
)
if
len
(
data_i
)
==
0
:
out_i
=
[]
else
:
out_i
=
dequantizer_type
.
_dequantize_func
(
data_i
.
reshape
(
data_shape_i
),
scale_inv_i
.
reshape
(
scale_shape_i
),
grouped_scaled_tensor
.
dq_dtype
,
scaling_mode
=
grouped_scaled_tensor
.
scaling_mode
,
is_colwise
=
grouped_scaled_tensor
.
is_colwise
,
flatten_axis
=
grouped_scaled_tensor
.
flatten_axis
,
)
output
.
append
(
out_i
)
scale_inv_ptr
+=
scale_shape_i_size
return
output
Dequantizer
.
grouped_dequantize
=
_grouped_dequantize
transformer_engine/jax/quantize/quantizer.py
View file @
2b05e121
...
@@ -9,7 +9,8 @@ This module provides classes and utilities for quantizing tensors in JAX.
...
@@ -9,7 +9,8 @@ This module provides classes and utilities for quantizing tensors in JAX.
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
partial
from
functools
import
partial
from
typing
import
Union
,
Optional
from
typing
import
Union
,
Optional
,
Tuple
import
warnings
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
...
@@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
...
@@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
from
transformer_engine_jax
import
QuantizeLayout
from
transformer_engine_jax
import
QuantizeLayout
from
.scaling_modes
import
ScalingMode
from
.scaling_modes
import
ScalingMode
from
.tensor
import
ScaledTensor1x
,
ScaledTensor2x
,
ScaledTensorFactory
from
.tensor
import
ScaledTensor
,
ScaledTensor1x
,
ScaledTensor2x
,
ScaledTensorFactory
from
.helper
import
(
from
.helper
import
(
QuantizeConfig
,
QuantizeConfig
,
AmaxComputeAlgo
,
AmaxComputeAlgo
,
...
@@ -30,6 +31,7 @@ __all__ = [
...
@@ -30,6 +31,7 @@ __all__ = [
"CurrentScaleQuantizer"
,
"CurrentScaleQuantizer"
,
"DelayedScaleQuantizer"
,
"DelayedScaleQuantizer"
,
"BlockScaleQuantizer"
,
"BlockScaleQuantizer"
,
"GroupedQuantizer"
,
"QuantizerFactory"
,
"QuantizerFactory"
,
"noop_quantizer_set"
,
"noop_quantizer_set"
,
"compute_scale_from_amax"
,
"compute_scale_from_amax"
,
...
@@ -74,6 +76,7 @@ class Quantizer(ABC):
...
@@ -74,6 +76,7 @@ class Quantizer(ABC):
q_dtype
:
jnp
.
dtype
q_dtype
:
jnp
.
dtype
scaling_mode
:
ScalingMode
scaling_mode
:
ScalingMode
q_layout
:
QuantizeLayout
q_layout
:
QuantizeLayout
data_layout
:
str
def
tree_flatten
(
self
):
def
tree_flatten
(
self
):
"""Flatten the quantizer for JAX tree operations.
"""Flatten the quantizer for JAX tree operations.
...
@@ -82,7 +85,7 @@ class Quantizer(ABC):
...
@@ -82,7 +85,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
Tuple of (children, aux_data) for tree operations
"""
"""
children
=
()
children
=
()
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_layout
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_layout
,
self
.
data_layout
)
return
(
children
,
aux_data
)
return
(
children
,
aux_data
)
@
classmethod
@
classmethod
...
@@ -110,13 +113,22 @@ class Quantizer(ABC):
...
@@ -110,13 +113,22 @@ class Quantizer(ABC):
"""
"""
return
self
.
q_layout
==
QuantizeLayout
.
ROWWISE_COLWISE
return
self
.
q_layout
==
QuantizeLayout
.
ROWWISE_COLWISE
@
abstractmethod
def
get_data_layout
(
self
)
->
str
:
def
get_data_layout
(
self
)
->
str
:
"""Get the data data_layout.
"""Get the data data_layout
string
.
Returns:
Returns:
Data data_layout in string format
Data data_layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
"""
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE_COLWISE
:
return
self
.
data_layout
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE
:
return
self
.
data_layout
[
0
]
if
self
.
q_layout
==
QuantizeLayout
.
COLWISE
:
return
self
.
data_layout
[
1
]
raise
ValueError
(
f
"Invalid q_layout:
{
self
.
q_layout
}
"
)
@
abstractmethod
@
abstractmethod
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
...
@@ -132,7 +144,9 @@ class Quantizer(ABC):
...
@@ -132,7 +144,9 @@ class Quantizer(ABC):
A ScaledTensor1x containing the quantized data
A ScaledTensor1x containing the quantized data
"""
"""
def
quantize
(
self
,
x
,
is_rowwise
=
False
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
):
def
quantize
(
self
,
x
,
is_rowwise
=
False
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
,
**
kwargs
)
->
ScaledTensor
:
"""Quantize a tensor using the internal _quantize_func().
"""Quantize a tensor using the internal _quantize_func().
Args:
Args:
...
@@ -145,6 +159,7 @@ class Quantizer(ABC):
...
@@ -145,6 +159,7 @@ class Quantizer(ABC):
Returns:
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
"""
del
kwargs
if
(
is_rowwise
and
is_colwise
)
or
self
.
is_2x2x
():
if
(
is_rowwise
and
is_colwise
)
or
self
.
is_2x2x
():
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
colwise_tensor
=
self
.
_quantize_func
(
colwise_tensor
=
self
.
_quantize_func
(
...
@@ -159,7 +174,7 @@ class Quantizer(ABC):
...
@@ -159,7 +174,7 @@ class Quantizer(ABC):
return
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
def
get_scale_shapes
(
self
,
data_shape
,
is_padded
=
True
,
flatten_axis
=-
1
):
def
get_scale_shapes
(
self
,
data_shape
,
is_padded
=
True
,
flatten_axis
=-
1
,
**
kwargs
):
"""Get shapes for scale tensors.
"""Get shapes for scale tensors.
Args:
Args:
...
@@ -169,6 +184,7 @@ class Quantizer(ABC):
...
@@ -169,6 +184,7 @@ class Quantizer(ABC):
Returns:
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
"""
del
kwargs
return
self
.
scaling_mode
.
get_scale_shape_2x
(
data_shape
,
is_padded
,
flatten_axis
)
return
self
.
scaling_mode
.
get_scale_shape_2x
(
data_shape
,
is_padded
,
flatten_axis
)
def
get_scale_dtype
(
self
):
def
get_scale_dtype
(
self
):
...
@@ -194,24 +210,7 @@ class CurrentScaleQuantizer(Quantizer):
...
@@ -194,24 +210,7 @@ class CurrentScaleQuantizer(Quantizer):
scaling_mode
:
ScalingMode
=
ScalingMode
.
CURRENT_TENSOR_SCALING
scaling_mode
:
ScalingMode
=
ScalingMode
.
CURRENT_TENSOR_SCALING
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE_COLWISE
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE_COLWISE
data_layout
:
str
=
"NT"
def
get_data_layout
(
self
)
->
str
:
"""Get the data data_layout string.
Returns:
Data data_layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
data_layout
=
"NT"
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE_COLWISE
:
return
data_layout
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE
:
return
data_layout
[
0
]
if
self
.
q_layout
==
QuantizeLayout
.
COLWISE
:
return
data_layout
[
1
]
raise
ValueError
(
f
"Invalid q_layout:
{
self
.
q_layout
}
"
)
def
_quantize_func
(
def
_quantize_func
(
self
,
x
:
jnp
.
ndarray
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
self
,
x
:
jnp
.
ndarray
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
...
@@ -230,16 +229,11 @@ class CurrentScaleQuantizer(Quantizer):
...
@@ -230,16 +229,11 @@ class CurrentScaleQuantizer(Quantizer):
compute_dtype
=
jnp
.
float32
compute_dtype
=
jnp
.
float32
dtype_max
=
(
jnp
.
finfo
(
self
.
q_dtype
).
max
).
astype
(
compute_dtype
)
dtype_max
=
(
jnp
.
finfo
(
self
.
q_dtype
).
max
).
astype
(
compute_dtype
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
)).
reshape
((
1
,))
.
astype
(
compute_dtype
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
)).
reshape
((
1
,))
fp8_max
=
jnp
.
astype
(
jnp
.
finfo
(
self
.
q_dtype
).
max
,
jnp
.
float32
)
fp8_max
=
jnp
.
astype
(
jnp
.
finfo
(
self
.
q_dtype
).
max
,
jnp
.
float32
)
scale
=
(
fp8_max
/
amax
)
/
(
2
**
QuantizeConfig
.
MARGIN
)
scale
=
(
fp8_max
/
amax
)
/
(
2
**
QuantizeConfig
.
MARGIN
)
scaled_x
=
x
.
astype
(
compute_dtype
)
*
scale
scaled_x
=
x
.
astype
(
compute_dtype
)
*
scale
# quantize() in the old dot.py do this way, leave this code block here for future debugging
# compute_dtype = x.dtype
# dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x
=
jnp
.
clip
(
scaled_x
,
-
dtype_max
,
dtype_max
).
astype
(
self
.
q_dtype
)
clipped_scaled_x
=
jnp
.
clip
(
scaled_x
,
-
dtype_max
,
dtype_max
).
astype
(
self
.
q_dtype
)
scale_inv
=
1.0
/
scale
scale_inv
=
1.0
/
scale
return
ScaledTensorFactory
.
create_1x
(
return
ScaledTensorFactory
.
create_1x
(
...
@@ -295,6 +289,7 @@ class CurrentScaleQuantizer(Quantizer):
...
@@ -295,6 +289,7 @@ class CurrentScaleQuantizer(Quantizer):
data_layout
=
"T"
,
data_layout
=
"T"
,
flatten_axis
=
flatten_axis
,
flatten_axis
=
flatten_axis
,
)
)
if
is_colwise
and
is_rowwise
:
if
is_colwise
and
is_rowwise
:
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
if
is_colwise
:
if
is_colwise
:
...
@@ -332,7 +327,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
...
@@ -332,7 +327,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Tuple of (children, aux_data) for tree operations
Tuple of (children, aux_data) for tree operations
"""
"""
children
=
(
self
.
scale
,
self
.
amax_history
)
children
=
(
self
.
scale
,
self
.
amax_history
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_layout
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_layout
,
self
.
data_layout
)
return
(
children
,
aux_data
)
return
(
children
,
aux_data
)
def
_quantize_func
(
def
_quantize_func
(
...
@@ -447,16 +442,7 @@ class BlockScaleQuantizer(Quantizer):
...
@@ -447,16 +442,7 @@ class BlockScaleQuantizer(Quantizer):
scaling_mode
:
ScalingMode
=
ScalingMode
.
MXFP8_1D_SCALING
scaling_mode
:
ScalingMode
=
ScalingMode
.
MXFP8_1D_SCALING
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE_COLWISE
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE_COLWISE
data_layout
:
str
=
"NN"
def
get_data_layout
(
self
)
->
str
:
"""Get the data data_layout string.
Returns:
Data data_layout in string format
"""
if
self
.
is_2x2x
():
return
"NN"
return
"N"
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Quantize function helper for block scaling FP8.
"""Quantize function helper for block scaling FP8.
...
@@ -591,6 +577,189 @@ class QuantizerSet:
...
@@ -591,6 +577,189 @@ class QuantizerSet:
return
cls
(
*
aux_data
,
*
children
)
return
cls
(
*
aux_data
,
*
children
)
@register_pytree_node_class
@dataclass
class GroupedQuantizer(Quantizer):
    """Quantizer for grouped arrays.

    This class extends Quantizer to support quantization of arrays in grouped manner,
    where elements are grouped along a specified axis then quantized separately.

    Attributes:
        data_layout: The data layout specification
        n_groups: Number of groups for quantization
        quantizers: Tuple of quantizers for each group
    """

    data_layout: str = None
    n_groups: int = 1
    quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,))

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.quantizers,)
        # aux_data order must match the dataclass field order so that the inherited
        # tree_unflatten (cls(*aux_data, *children)) reconstructs this instance.
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups)
        return (children, aux_data)

    def __post_init__(self):
        # Lazily create one quantizer per group when none were supplied.
        if self.quantizers[0] is None:
            quantizers = QuantizerFactory.create(
                self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
            )
            # QuantizerFactory.create returns a bare Quantizer (not a tuple) when
            # n_groups == 1; normalize so self.quantizers is always indexable.
            if not isinstance(quantizers, tuple):
                quantizers = (quantizers,)
            self.quantizers = quantizers
        self.data_layout = self.quantizers[0].data_layout

    def _create_grouped_tensor_from_tensor_list(
        self, tensor_list, group_sizes, original_shape, group_axis, mode
    ):
        """Combine per-group ScaledTensor1x results into one grouped tensor.

        mode 0 = concate (flatten each group's data and concatenate),
        mode 1 = add (element-wise sum of the per-group masked tensors).
        """
        # TODO(Ming Huang): Consider to apply Enum for mode.
        assert mode in [0, 1]
        grouped_data = (
            [] if mode == 0 else jnp.zeros(tensor_list[0].data.shape, tensor_list[0].data.dtype)
        )
        grouped_scale_inv = []
        for tensor in tensor_list:
            if mode == 0:
                grouped_data.append(tensor.data.flatten())
            else:
                grouped_data += tensor.data
            # Scale inverses are always flattened and concatenated, whatever the mode.
            grouped_scale_inv.append(tensor.scale_inv.flatten())
        grouped_data = jnp.concatenate(grouped_data) if mode == 0 else grouped_data.flatten()
        grouped_scale_inv = jnp.concatenate(grouped_scale_inv)
        # Metadata (dq_dtype, layout, flatten_axis, ...) is taken from the first
        # group's tensor; all groups are produced by equally-configured quantizers.
        return ScaledTensorFactory.create_1x(
            grouped_data,
            grouped_scale_inv,
            self.scaling_mode,
            tensor_list[0].dq_dtype,
            tensor_list[0].is_colwise,
            tensor_list[0].data_layout,
            tensor_list[0].flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )

    def _quantize_func(self, *args, **kwargs):
        # Not used: GroupedQuantizer overrides quantize() directly and dispatches
        # to the per-group quantizers' own _quantize_func implementations.
        pass

    def quantize(
        self,
        x,
        is_rowwise: bool = None,
        is_colwise: bool = None,
        dq_dtype=None,
        flatten_axis=-1,
        group_sizes=None,
        group_axis=0,
    ):
        """Quantize a tensor in grouped manner.

        Expected input shape: [M, K] or [G, K, N]
        Split to x.shape[group_axis] number of groups if group_sizes is not given

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        assert group_axis == 0, "Only group_axis == 0 is supported now!"
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype

        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        # Fall back to this quantizer's configured layout when not given explicitly.
        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
        )
        assert is_rowwise or is_colwise, "No quantization layout is specified"

        original_shape = x.shape
        if group_sizes is not None:
            assert not is_colwise, "Not yet implememted!"
            assert group_sizes.ndim == 1, (
                "GroupedQuantizer only support 1D group_sizes, got group_sizes.ndim ="
                f" {group_sizes.ndim}"
            )
            _zeros = partial(jax.lax.full_like, fill_value=0)
            # Row index of every element, compared against each group's [start, end)
            # to build a per-group mask; out-of-group elements are zeroed before
            # quantization so each group's amax only sees its own rows.
            x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0)
            group_ends = jnp.cumulative_sum(group_sizes)
            group_starts = jax.lax.concatenate(
                [_zeros(group_sizes)[:1], group_ends[:-1]],
                dimension=0,
            )
            x_zero = _zeros(x)
            tensor_list = []
            for i in range(len(group_sizes)):
                mask = jax.lax.bitwise_and(group_starts[i] <= x_iota, x_iota < group_ends[i])
                x_selected = jax.lax.select(mask, x, x_zero)
                tensor = self.quantizers[i].quantize(
                    x_selected, is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 1  # Add: masked full-size tensors are summed back together
        else:
            # No explicit sizes: treat every slice along group_axis as its own group.
            group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
            x = jnp.split(x, x.shape[group_axis], axis=group_axis)
            tensor_list = []
            for i in range(len(group_sizes)):
                tensor = self.quantizers[i].quantize(
                    x[i], is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 0  # Concate: per-group slices are flattened and concatenated

        grouped_rowwise_tensor = grouped_colwise_tensor = None
        if is_rowwise:
            rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list]
            grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list(
                rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )
        if is_colwise:
            colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list]
            grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list(
                colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )
        if is_colwise and is_rowwise:
            return ScaledTensor2x(grouped_rowwise_tensor, grouped_colwise_tensor)
        if is_colwise:
            return grouped_colwise_tensor
        return grouped_rowwise_tensor

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, group_sizes=None):
        """Get shapes for the grouped scale tensors.

        Args:
            data_shape: Shape of the (ungrouped) data tensor
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Sizes of each group; must be provided

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        # Use an explicit None check: `assert group_sizes` raises a TypeError on
        # multi-element jax/numpy arrays (ambiguous truth value) instead of the
        # intended assertion failure.
        assert group_sizes is not None, "Empty group_sizes was given!"
        # NOTE(review): get_grouped_scale_shape_2x's signature is
        # (data_shape, n_groups, group_axis, is_padded, flatten_axis); here
        # group_sizes lands in the n_groups slot and is_padded in group_axis —
        # confirm callers pass an int group count here, or fix the call site.
        return self.scaling_mode.get_grouped_scale_shape_2x(
            data_shape, group_sizes, is_padded, flatten_axis
        )
@
dataclass
@
dataclass
class
QuantizerFactory
:
class
QuantizerFactory
:
"""Factory class for creating quantizers.
"""Factory class for creating quantizers.
...
@@ -611,6 +780,7 @@ class QuantizerFactory:
...
@@ -611,6 +780,7 @@ class QuantizerFactory:
scaling_mode
:
ScalingMode
=
None
,
scaling_mode
:
ScalingMode
=
None
,
q_dtype
:
jnp
.
dtype
=
None
,
q_dtype
:
jnp
.
dtype
=
None
,
q_layout
:
QuantizeLayout
=
None
,
q_layout
:
QuantizeLayout
=
None
,
n_groups
:
int
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Quantizer
:
)
->
Quantizer
:
"""Create one or more quantizers with specified parameters.
"""Create one or more quantizers with specified parameters.
...
@@ -621,6 +791,7 @@ class QuantizerFactory:
...
@@ -621,6 +791,7 @@ class QuantizerFactory:
q_dtype: Quantization data type
q_dtype: Quantization data type
q_layout: Quantization axis
q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
flatten_axis: The quantization axis for the tensor
n_groups: Number of quantizers if GroupedQuantizer
**kwargs: Additional arguments for quantizer initialization
**kwargs: Additional arguments for quantizer initialization
Returns:
Returns:
...
@@ -628,13 +799,21 @@ class QuantizerFactory:
...
@@ -628,13 +799,21 @@ class QuantizerFactory:
"""
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
assert
isinstance
(
scaling_mode
,
ScalingMode
),
"Invalid scaling_mode type"
assert
isinstance
(
scaling_mode
,
ScalingMode
),
"Invalid scaling_mode type"
# import pdb; pdb.set_trace()
if
n_groups
:
if
n_quantizers
!=
1
:
warnings
.
warn
(
"Using more than one GroupedQuantizer for a grouped input is not recommended"
)
quantizer_type
=
GroupedQuantizer
kwargs
[
"n_groups"
]
=
n_groups
else
:
quantizer_type
=
QuantizerFactory
.
quantizer_type_map
.
get
(
scaling_mode
)
if
scaling_mode
==
ScalingMode
.
NO_SCALING
:
if
scaling_mode
==
ScalingMode
.
NO_SCALING
:
quantizers
=
[
None
]
*
n_quantizers
quantizers
=
[
None
]
*
n_quantizers
else
:
else
:
quantizers
=
[]
quantizers
=
[]
for
_
in
range
(
n_quantizers
):
for
_
in
range
(
n_quantizers
):
quantizer_type
=
QuantizerFactory
.
quantizer_type_map
.
get
(
scaling_mode
)
quantizers
.
append
(
quantizers
.
append
(
quantizer_type
(
quantizer_type
(
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_layout
=
q_layout
,
**
kwargs
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_layout
=
q_layout
,
**
kwargs
...
@@ -643,7 +822,9 @@ class QuantizerFactory:
...
@@ -643,7 +822,9 @@ class QuantizerFactory:
return
quantizers
[
0
]
if
len
(
quantizers
)
==
1
else
tuple
(
quantizers
)
return
quantizers
[
0
]
if
len
(
quantizers
)
==
1
else
tuple
(
quantizers
)
@
staticmethod
@
staticmethod
def
_create_set
(
scaling_mode
,
fwd_dtype
,
bwd_dtype
,
is_2x2x
,
**
kwargs
)
->
QuantizerSet
:
def
_create_set
(
scaling_mode
,
fwd_dtype
,
bwd_dtype
,
is_2x2x
,
n_groups
,
**
kwargs
)
->
QuantizerSet
:
"""Create a set of quantizers for forward and backward passes.
"""Create a set of quantizers for forward and backward passes.
Args:
Args:
...
@@ -651,6 +832,7 @@ class QuantizerFactory:
...
@@ -651,6 +832,7 @@ class QuantizerFactory:
fwd_dtype: Data type for forward pass
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
is_2x2x: Whether to use 2x2x quantization
n_groups
**kwargs: Additional arguments for quantizer initialization
**kwargs: Additional arguments for quantizer initialization
Returns:
Returns:
...
@@ -680,11 +862,13 @@ class QuantizerFactory:
...
@@ -680,11 +862,13 @@ class QuantizerFactory:
else
:
else
:
args_x
=
args_kernel
=
args_grad
=
{}
args_x
=
args_kernel
=
args_grad
=
{}
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_layout_x
,
**
args_x
)
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_layout_x
,
n_groups
,
**
args_x
)
q_kernel
=
QuantizerFactory
.
create
(
q_kernel
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_layout_kernel
,
**
args_kernel
1
,
scaling_mode
,
fwd_dtype
,
q_layout_kernel
,
n_groups
,
**
args_kernel
)
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_layout_dgrad
,
n_groups
,
**
args_grad
)
)
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_layout_dgrad
,
**
args_grad
)
return
QuantizerSet
(
x
=
q_x
,
kernel
=
q_kernel
,
dgrad
=
q_dgrad
)
return
QuantizerSet
(
x
=
q_x
,
kernel
=
q_kernel
,
dgrad
=
q_dgrad
)
@
staticmethod
@
staticmethod
...
@@ -694,6 +878,7 @@ class QuantizerFactory:
...
@@ -694,6 +878,7 @@ class QuantizerFactory:
fwd_dtype
:
jnp
.
dtype
=
None
,
fwd_dtype
:
jnp
.
dtype
=
None
,
bwd_dtype
:
jnp
.
dtype
=
None
,
bwd_dtype
:
jnp
.
dtype
=
None
,
is_2x2x
:
bool
=
None
,
is_2x2x
:
bool
=
None
,
n_groups
:
int
=
None
,
**
kwargs
,
**
kwargs
,
)
->
tuple
[
Union
[
tuple
[
Quantizer
],
None
]]:
)
->
tuple
[
Union
[
tuple
[
Quantizer
],
None
]]:
"""Create one or more sets of quantizers.
"""Create one or more sets of quantizers.
...
@@ -704,6 +889,7 @@ class QuantizerFactory:
...
@@ -704,6 +889,7 @@ class QuantizerFactory:
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
n_groups:
**kwargs: Additional arguments for quantizer initialization
**kwargs: Additional arguments for quantizer initialization
Returns:
Returns:
...
@@ -717,7 +903,9 @@ class QuantizerFactory:
...
@@ -717,7 +903,9 @@ class QuantizerFactory:
q_set
=
[]
q_set
=
[]
for
_
in
range
(
n_quantizer_sets
):
for
_
in
range
(
n_quantizer_sets
):
q_set
.
append
(
q_set
.
append
(
QuantizerFactory
.
_create_set
(
scaling_mode
,
fwd_dtype
,
bwd_dtype
,
is_2x2x
,
**
kwargs
)
QuantizerFactory
.
_create_set
(
scaling_mode
,
fwd_dtype
,
bwd_dtype
,
is_2x2x
,
n_groups
,
**
kwargs
)
)
)
return
q_set
[
0
]
if
len
(
q_set
)
==
1
else
tuple
(
q_set
)
return
q_set
[
0
]
if
len
(
q_set
)
==
1
else
tuple
(
q_set
)
...
...
transformer_engine/jax/quantize/scaling_modes.py
View file @
2b05e121
...
@@ -15,6 +15,7 @@ from enum import Enum
...
@@ -15,6 +15,7 @@ from enum import Enum
from
typing
import
Tuple
,
Dict
from
typing
import
Tuple
,
Dict
from
functools
import
reduce
from
functools
import
reduce
import
operator
import
operator
import
numpy
as
np
from
jax.experimental.custom_partitioning
import
CompoundFactor
from
jax.experimental.custom_partitioning
import
CompoundFactor
from
jax.tree_util
import
register_pytree_node_class
from
jax.tree_util
import
register_pytree_node_class
...
@@ -26,6 +27,11 @@ from transformer_engine_jax import JAXX_Scaling_Mode
...
@@ -26,6 +27,11 @@ from transformer_engine_jax import JAXX_Scaling_Mode
__all__
=
[
"QuantizeShardyRules"
,
"ScalingMode"
]
__all__
=
[
"QuantizeShardyRules"
,
"ScalingMode"
]
def DIVUP(a, b):
    """Divide ``a`` by ``b`` and round the quotient up (ceiling division)."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)
@
dataclass
@
dataclass
class
QuantizeShardyRules
:
class
QuantizeShardyRules
:
"""Information necessary to shard scale tensors with Shardy.
"""Information necessary to shard scale tensors with Shardy.
...
@@ -74,7 +80,26 @@ class ScalingModeMetadataImpl(ABC):
...
@@ -74,7 +80,26 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
Returns:
The shape for scale tensors
"""
@abstractmethod
def get_grouped_scale_shape(
    self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
    """Get the shape for scale tensors in this mode.

    Args:
        data_shape: Original shape of the data tensor
        n_groups: Number of groups in grouped quantization
        group_axis: The axis along which grouping is performed
        is_colwise: Whether to use column-wise scaling
        is_padded: Whether to use padded shapes
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        The shape for scale tensors
    """
"""
...
@@ -127,9 +152,29 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -127,9 +152,29 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
Returns:
The shape for scale tensors - (1,)
The shape for scale tensors - (1,)
"""
"""
del
data_shape
,
is_colwise
del
is_colwise
if
np
.
prod
(
data_shape
)
==
0
:
return
(
0
,)
return
(
1
,)
return
(
1
,)
def get_grouped_scale_shape(
    self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
    """Get the shape for grouped scale tensors in this (per-tensor) scaling mode.

    One scalar scale is kept per group, so the result depends only on n_groups.

    Args:
        data_shape: Original shape of the data tensor (unused)
        n_groups: Number of groups in grouped quantization
        group_axis: The axis along which grouping is performed (unused)
        is_colwise: Whether to use column-wise scaling (unused)
        is_padded: Whether to use padded shapes (unused in this mode)
        flatten_axis: The axis along which the tensor could be flattened to 2D (unused in this mode)

    Returns:
        The shape for scale tensors - (n_groups,)
    """
    del data_shape, group_axis, is_colwise
    assert isinstance(n_groups, int)
    return (n_groups,)
def
get_shardy_sharding_rules
(
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
)
->
QuantizeShardyRules
:
...
@@ -276,6 +321,77 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -276,6 +321,77 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return
(
*
first_dim_scale_shape
,
*
last_dim_scale_shape
)
return
(
*
first_dim_scale_shape
,
*
last_dim_scale_shape
)
def get_grouped_scale_shape(
    self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
    """Get the shape for grouped scale tensors in this mode.

    If padded: The estimiated maximal possible shape for grouped scale tensor is return instead.

    Args:
        data_shape: Original shape of the data tensor
        n_groups: Number of groups in grouped quantization
        group_axis: The axis along which grouping is performed (unused here)
        is_colwise: Whether to use column-wise scaling
        is_padded: Whether to use padded shapes
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        The (flattened, 1D) shape for scale tensors
    """
    assert isinstance(n_groups, int)
    # Without padding, alignment collapses to (1, 1) and the result is exact.
    block_alignment = self._block_alignment if is_padded else (1, 1)

    # Column-wise scaling swaps which block dimension runs along rows vs columns.
    if is_colwise:
        block_y, block_x = self._block_dims
        alignment_y, alignment_x = block_alignment
    else:
        block_x, block_y = self._block_dims
        alignment_x, alignment_y = block_alignment

    if flatten_axis < 0:
        flatten_axis = len(data_shape) + flatten_axis
    assert (
        0 < flatten_axis < len(data_shape)
    ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

    assert data_shape[flatten_axis - 1] % block_x == 0, (
        f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
        f" {flatten_axis - 1}"
    )
    assert (
        data_shape[-1] % block_y == 0
    ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

    # Collapse the shape to 2D around flatten_axis; scales are computed per block
    # of that 2D view.
    flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
    flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

    assert flattened_first_dim % block_x == 0, (
        f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
        f" {data_shape} - should be divisible by block_x {block_x}"
    )
    assert flattened_last_dim % block_y == 0, (
        "Flattened last dim - mutiplication of"
        f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
        f" divisible by block_y {block_y}"
    )

    n_block_x = int(flattened_first_dim // block_x)
    n_block_y = int(flattened_last_dim // block_y)

    """
    Given the scale shape of [M, N], and G groups, and padding alignment (128, 4),
    The worst scenario is when we have (G-1) groups with 1 rows and 1 group with (M-G+1) rows.
    Then:
        max_padded_rows = (G-1) * 128 + DIVUP(M-G+1, 128) * 128
        max_padded_cols = DIVUP(N, 4) * 4
        max_scale_size = max_padded_rows * max_padded_cols
    """
    if is_padded:
        # Worst-case upper bound (see note above): each of the first G-1 groups may
        # occupy a full alignment_x chunk, the remainder goes to the last group.
        n_block_x = (n_groups - 1) * alignment_x + DIVUP(
            n_block_x - n_groups + 1, alignment_x
        ) * alignment_x
        n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y

    return (n_block_x * n_block_y,)
def
get_shardy_sharding_rules
(
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
)
->
QuantizeShardyRules
:
...
@@ -404,6 +520,61 @@ class ScalingMode(Enum):
...
@@ -404,6 +520,61 @@ class ScalingMode(Enum):
"""
"""
return
self
.
_get_impl
().
get_shardy_sharding_rules
(
input_rank
,
unique_var
,
flatten_axis
)
return
self
.
_get_impl
().
get_shardy_sharding_rules
(
input_rank
,
unique_var
,
flatten_axis
)
def get_grouped_scale_shape_2x(
    self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
) -> Tuple[Tuple[int]]:
    """Compute grouped scale shapes for both quantization directions.

    Args:
        data_shape: Shape of the data tensor
        n_groups: Number of groups for grouped quantization
        group_axis: The axis along which grouping is performed
        is_padded: Whether to use padded shapes
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        Tuple of (rowwise_scale_shape, colwise_scale_shape)
    """
    # Evaluate row-wise first, then column-wise, via a single parameterized call.
    rowwise_scale_shape, colwise_scale_shape = (
        self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=use_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        for use_colwise in (False, True)
    )
    return (rowwise_scale_shape, colwise_scale_shape)
def get_grouped_scale_shape(
    self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[Tuple[int]]:
    """Return the grouped scale-tensor shape for one quantization direction.

    Delegates to this scaling mode's metadata implementation.

    Args:
        data_shape: Shape of the data tensor
        n_groups: Number of groups for grouped quantization
        group_axis: The axis along which grouping is performed
        is_colwise: Whether to use column-wise scaling
        is_padded: Whether to use padded shapes
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        The shape for the grouped scale tensor
    """
    impl = self._get_impl()
    return impl.get_grouped_scale_shape(
        data_shape,
        n_groups,
        group_axis,
        is_colwise=is_colwise,
        is_padded=is_padded,
        flatten_axis=flatten_axis,
    )
def
is_tensor_scaling
(
self
)
->
bool
:
def
is_tensor_scaling
(
self
)
->
bool
:
"""Check if this scaling mode is per-tensor scaling.
"""Check if this scaling mode is per-tensor scaling.
...
...
transformer_engine/jax/quantize/tensor.py
View file @
2b05e121
...
@@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
...
@@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
from
transformer_engine_jax
import
QuantizeLayout
from
transformer_engine_jax
import
QuantizeLayout
from
.scaling_modes
import
ScalingMode
from
.scaling_modes
import
ScalingMode
from
.dequantizer
import
Dequantizer
from
.dequantizer
import
ScalingModeTo
Dequantizer
Map
from
..sharding
import
(
from
..sharding
import
(
with_sharding_constraint_by_logical_axes
as
original_with_sharding_constraint_by_logical_axes
,
with_sharding_constraint_by_logical_axes
as
original_with_sharding_constraint_by_logical_axes
,
)
)
...
@@ -27,6 +27,7 @@ __all__ = [
...
@@ -27,6 +27,7 @@ __all__ = [
"ScaledTensor"
,
"ScaledTensor"
,
"ScaledTensor1x"
,
"ScaledTensor1x"
,
"ScaledTensor2x"
,
"ScaledTensor2x"
,
"GroupedScaledTensor1x"
,
"ScaledTensorFactory"
,
"ScaledTensorFactory"
,
"with_sharding_constraint_by_logical_axes"
,
"with_sharding_constraint_by_logical_axes"
,
]
]
...
@@ -122,7 +123,7 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -122,7 +123,7 @@ class ScaledTensor1x(ScaledTensor):
_dq_func
:
Callable
_dq_func
:
Callable
is_colwise
:
bool
is_colwise
:
bool
data_layout
:
str
data_layout
:
str
flatten_axis
:
int
=
-
1
flatten_axis
:
int
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Validates and adjusts the scale_inv shape after initialization.
"""Validates and adjusts the scale_inv shape after initialization.
...
@@ -130,22 +131,16 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -130,22 +131,16 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
and quantization direction. Pads the scale_inv if necessary.
"""
"""
flatten_axis
=
(
assert
self
.
flatten_axis
>
0
len
(
self
.
data
.
shape
)
+
self
.
flatten_axis
if
self
.
flatten_axis
<
0
else
self
.
flatten_axis
)
assert
(
assert
(
0
<
flatten_axis
<
len
(
self
.
data
.
shape
)
0
<
self
.
flatten_axis
<
len
(
self
.
data
.
shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
self
.
data
.
shape
}
"
),
f
"flatten_axis
{
self
.
flatten_axis
}
is out of bounds for shape
{
self
.
data
.
shape
}
"
if
self
.
data_layout
==
"T"
:
flatten_axis
=
self
.
data
.
ndim
-
flatten_axis
self
.
flatten_axis
=
flatten_axis
expected_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
expected_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
True
,
flatten_axis
=
flatten_axis
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
True
,
flatten_axis
=
self
.
flatten_axis
)
)
expected_unpadded_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
expected_unpadded_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
self
.
flatten_axis
)
)
if
self
.
scale_inv
.
shape
!=
expected_scale_shape
:
if
self
.
scale_inv
.
shape
!=
expected_scale_shape
:
assert
self
.
scale_inv
.
shape
==
expected_unpadded_scale_shape
,
(
assert
self
.
scale_inv
.
shape
==
expected_unpadded_scale_shape
,
(
...
@@ -229,8 +224,12 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -229,8 +224,12 @@ class ScaledTensor1x(ScaledTensor):
# axis_names were given for N layout, so needs to be transpose for T layout
# axis_names were given for N layout, so needs to be transpose for T layout
if
self
.
data_layout
==
"T"
:
if
self
.
data_layout
==
"T"
:
assert
self
.
flatten_axis
>
0
assert
self
.
flatten_axis
>
0
flatten_axis
=
-
self
.
flatten_axis
assert
len
(
logical_axis_names
)
==
self
.
data
.
ndim
axis_names
=
(
*
logical_axis_names
[
flatten_axis
:],
*
logical_axis_names
[:
flatten_axis
])
flatten_axis
=
self
.
data
.
ndim
-
self
.
flatten_axis
axis_names
=
(
*
logical_axis_names
[
flatten_axis
:],
*
logical_axis_names
[:
flatten_axis
],
)
else
:
else
:
axis_names
=
logical_axis_names
axis_names
=
logical_axis_names
...
@@ -254,6 +253,98 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -254,6 +253,98 @@ class ScaledTensor1x(ScaledTensor):
)
)
@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped Quantizer for an array.

    This class extends ScaledTensor1x to support quantization of an array in grouped manner,
    where elements are grouped along a specified axis.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    # NOTE: a manual __init__ is defined below, so @dataclass does not generate
    # one here; the parent ScaledTensor1x's dataclass-generated __init__ is what
    # ends up invoking __post_init__ (via super().__init__).
    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        # Group-specific attributes must be set before super().__init__ runs,
        # because the parent's __init__ triggers __post_init__, which reads them.
        self.flatten_axis = flatten_axis
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        super().__init__(
            data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis
        )

    def __post_init__(self):
        """Validates grouped-tensor invariants and the padded scale_inv shape.

        Both data and scale_inv are expected in flattened (1-D) form; the
        logical layout is recovered from original_shape / group_axis /
        flatten_axis.

        Raises:
            AssertionError: If data/scale_inv are not 1-D, if flatten_axis or
                group_axis is out of bounds for original_shape, or if
                scale_inv does not have the expected padded scale shape.
        """
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"
        assert self.group_axis >= 0
        assert self.flatten_axis > 0
        data_ndim = len(self.original_shape)
        assert (
            0 < self.flatten_axis < data_ndim
        ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"
        assert (
            0 <= self.group_axis < data_ndim
        ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"
        # The expected scale shape accounts for per-group padding introduced by
        # the scaling mode (is_padded=True).
        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=self.flatten_axis,
        )
        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape!\nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        # Dynamic (traced) leaves: the quantized payload, its scales, and the
        # per-group sizes. Everything else is static metadata.
        children = (self.data, self.scale_inv, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        # Sharding constraints are not yet supported for grouped tensors; the
        # flattened layout has no direct mapping to per-axis logical names.
        raise NotImplementedError
@
register_pytree_node_class
@
register_pytree_node_class
@
dataclass
@
dataclass
class
ScaledTensor2x
(
ScaledTensor
):
class
ScaledTensor2x
(
ScaledTensor
):
...
@@ -342,6 +433,9 @@ class ScaledTensorFactory:
...
@@ -342,6 +433,9 @@ class ScaledTensorFactory:
is_colwise
=
False
,
is_colwise
=
False
,
data_layout
=
"N"
,
data_layout
=
"N"
,
flatten_axis
=-
1
,
flatten_axis
=-
1
,
group_sizes
=
None
,
original_shape
=
None
,
group_axis
=
0
,
):
):
"""Creates a single-scale quantized tensor.
"""Creates a single-scale quantized tensor.
...
@@ -353,13 +447,67 @@ class ScaledTensorFactory:
...
@@ -353,13 +447,67 @@ class ScaledTensorFactory:
is_colwise: Whether to use column-wise quantization (default: False)
is_colwise: Whether to use column-wise quantization (default: False)
data_layout: The data_layout specification (default: "N")
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
flatten_axis: The quantization axis for the tensor
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
Returns:
A ScaledTensor1x
instance
A ScaledTensor1x
or GroupedScaledTensor1x instance depending on whether group_sizes is provided
"""
"""
dq_func
=
Dequantizer
.
funcs
.
get
(
scaling_mode
)
dequantizer
=
ScalingModeToDequantizerMap
.
get
(
scaling_mode
)
if
group_sizes
is
not
None
:
flatten_axis
=
len
(
original_shape
)
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
assert
(
original_shape
is
not
None
),
"original_shape is not given for GroupedScaledTensor1x"
# Handling attrs of transposed tensors
group_axis
=
len
(
original_shape
)
+
group_axis
if
group_axis
<
0
else
group_axis
if
data_layout
==
"T"
:
if
original_shape
[
0
]
==
group_sizes
.
size
:
original_shape
=
(
original_shape
[
0
],
*
original_shape
[
flatten_axis
:],
*
original_shape
[
1
:
flatten_axis
],
)
flatten_axis
=
len
(
original_shape
)
-
flatten_axis
+
1
else
:
original_shape
=
(
*
original_shape
[
flatten_axis
:],
*
original_shape
[:
flatten_axis
],
)
group_axis
=
flatten_axis
flatten_axis
=
len
(
original_shape
)
-
flatten_axis
return
GroupedScaledTensor1x
(
data
=
data
,
scale_inv
=
scale_inv
,
scaling_mode
=
scaling_mode
,
dq_dtype
=
dq_dtype
,
_dq_func
=
dequantizer
.
grouped_dequantize
,
is_colwise
=
is_colwise
,
data_layout
=
data_layout
,
flatten_axis
=
flatten_axis
,
group_sizes
=
group_sizes
,
original_shape
=
original_shape
,
group_axis
=
group_axis
,
)
# Handling attrs of transposed tensors
flatten_axis
=
data
.
ndim
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
if
data_layout
==
"T"
:
flatten_axis
=
data
.
ndim
-
flatten_axis
return
ScaledTensor1x
(
return
ScaledTensor1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
dq_func
,
is_colwise
,
data_layout
,
flatten_axis
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
dequantizer
.
dequantize
,
is_colwise
,
data_layout
,
flatten_axis
,
)
)
@
staticmethod
@
staticmethod
...
@@ -372,6 +520,9 @@ class ScaledTensorFactory:
...
@@ -372,6 +520,9 @@ class ScaledTensorFactory:
dq_dtype
=
jnp
.
bfloat16
,
dq_dtype
=
jnp
.
bfloat16
,
data_layout
=
"NN"
,
data_layout
=
"NN"
,
flatten_axis
=-
1
,
flatten_axis
=-
1
,
group_sizes
=
None
,
original_shape
=
None
,
group_axis
=
0
,
):
):
"""Creates a double-scale quantized tensor.
"""Creates a double-scale quantized tensor.
...
@@ -384,30 +535,37 @@ class ScaledTensorFactory:
...
@@ -384,30 +535,37 @@ class ScaledTensorFactory:
dq_dtype: The data type for dequantized values (default: bfloat16)
dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN")
data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor
flatten_axis: The quantization axis for the tensor
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
Returns:
A ScaledTensor2x instance
A ScaledTensor2x instance
"""
"""
dq_func
=
Dequantizer
.
funcs
.
get
(
scaling_mode
)
assert
len
(
data_layout
)
==
2
,
f
"Expect 2 layouts, got
{
data_layout
}
"
rowwise_tensor
=
ScaledTensor1x
(
rowwise_tensor
=
ScaledTensor
Factory
.
create_
1x
(
data
,
data
,
scale_inv
,
scale_inv
,
scaling_mode
,
scaling_mode
,
dq_dtype
,
dq_dtype
,
dq_func
,
is_colwise
=
False
,
is_colwise
=
False
,
data_layout
=
data_layout
[
0
],
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
flatten_axis
=
flatten_axis
,
group_sizes
=
group_sizes
,
original_shape
=
original_shape
,
group_axis
=
group_axis
,
)
)
colwise_tensor
=
ScaledTensor1x
(
colwise_tensor
=
ScaledTensor
Factory
.
create_
1x
(
colwise_data
,
colwise_data
,
colwise_scale_inv
,
colwise_scale_inv
,
scaling_mode
,
scaling_mode
,
dq_dtype
,
dq_dtype
,
dq_func
,
is_colwise
=
True
,
is_colwise
=
True
,
data_layout
=
data_layout
[
1
],
data_layout
=
data_layout
[
1
],
flatten_axis
=
flatten_axis
,
flatten_axis
=
flatten_axis
,
group_sizes
=
group_sizes
,
original_shape
=
original_shape
,
group_axis
=
group_axis
,
)
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
...
@@ -422,6 +580,9 @@ class ScaledTensorFactory:
...
@@ -422,6 +580,9 @@ class ScaledTensorFactory:
data_layout
:
str
=
"NN"
,
data_layout
:
str
=
"NN"
,
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE
,
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE
,
flatten_axis
:
int
=
-
1
,
flatten_axis
:
int
=
-
1
,
group_sizes
:
jnp
.
ndarray
=
None
,
original_shape
:
Tuple
[
int
]
=
None
,
group_axis
:
int
=
0
,
):
):
"""Creates a scaled tensor based on the quantization axis.
"""Creates a scaled tensor based on the quantization axis.
...
@@ -434,6 +595,10 @@ class ScaledTensorFactory:
...
@@ -434,6 +595,10 @@ class ScaledTensorFactory:
dq_dtype: The data type for dequantized values (default: bfloat16)
dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN")
data_layout: The data_layout specification (default: "NN")
q_layout: The quantization axis (default: ROWWISE)
q_layout: The quantization axis (default: ROWWISE)
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
...
@@ -448,9 +613,26 @@ class ScaledTensorFactory:
...
@@ -448,9 +613,26 @@ class ScaledTensorFactory:
dq_dtype
,
dq_dtype
,
data_layout
=
data_layout
,
data_layout
=
data_layout
,
flatten_axis
=
flatten_axis
,
flatten_axis
=
flatten_axis
,
group_sizes
=
group_sizes
,
original_shape
=
original_shape
,
group_axis
=
group_axis
,
)
)
is_colwise
=
q_layout
==
QuantizeLayout
.
COLWISE
is_colwise
=
q_layout
==
QuantizeLayout
.
COLWISE
if
is_colwise
:
return
ScaledTensorFactory
.
create_1x
(
colwise_data
,
colwise_scale_inv
,
scaling_mode
,
dq_dtype
,
is_colwise
=
is_colwise
,
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
group_sizes
=
group_sizes
,
original_shape
=
original_shape
,
group_axis
=
group_axis
,
)
return
ScaledTensorFactory
.
create_1x
(
return
ScaledTensorFactory
.
create_1x
(
data
,
data
,
scale_inv
,
scale_inv
,
...
@@ -459,6 +641,9 @@ class ScaledTensorFactory:
...
@@ -459,6 +641,9 @@ class ScaledTensorFactory:
is_colwise
=
is_colwise
,
is_colwise
=
is_colwise
,
data_layout
=
data_layout
[
0
],
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
flatten_axis
=
flatten_axis
,
group_sizes
=
group_sizes
,
original_shape
=
original_shape
,
group_axis
=
group_axis
,
)
)
...
@@ -472,6 +657,9 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
...
@@ -472,6 +657,9 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns:
Returns:
The tensor with applied sharding constraints
The tensor with applied sharding constraints
"""
"""
if
isinstance
(
x
,
GroupedScaledTensor1x
):
raise
NotImplementedError
if
isinstance
(
x
,
ScaledTensor
):
if
isinstance
(
x
,
ScaledTensor
):
return
x
.
apply_sharding_constraint_by_logical_axes
(
logical_axis_names
)
return
x
.
apply_sharding_constraint_by_logical_axes
(
logical_axis_names
)
...
...
transformer_engine/jax/setup.py
View file @
2b05e121
...
@@ -44,11 +44,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
...
@@ -44,11 +44,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from
build_tools.build_ext
import
get_build_ext
from
build_tools.build_ext
import
get_build_ext
from
build_tools.utils
import
copy_common_headers
,
install_and_import
from
build_tools.utils
import
copy_common_headers
from
build_tools.te_version
import
te_version
from
build_tools.te_version
import
te_version
from
build_tools.jax
import
setup_jax_extension
from
build_tools.jax
import
setup_jax_extension
,
install_requirements
,
test_requirements
install_and_import
(
"pybind11"
)
from
pybind11.setup_helpers
import
build_ext
as
BuildExtension
from
pybind11.setup_helpers
import
build_ext
as
BuildExtension
os
.
environ
[
"NVTE_PROJECT_BUILDING"
]
=
"1"
os
.
environ
[
"NVTE_PROJECT_BUILDING"
]
=
"1"
...
@@ -101,19 +100,8 @@ if __name__ == "__main__":
...
@@ -101,19 +100,8 @@ if __name__ == "__main__":
description
=
"Transformer acceleration library - Jax Lib"
,
description
=
"Transformer acceleration library - Jax Lib"
,
ext_modules
=
ext_modules
,
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
CMakeBuildExtension
},
cmdclass
=
{
"build_ext"
:
CMakeBuildExtension
},
setup_requires
=
[
install_requires
=
install_requirements
(),
"jax[cuda12]"
,
tests_require
=
test_requirements
(),
"flax>=0.7.1"
,
"nvidia-cuda-runtime-cu12"
,
"nvidia-cublas-cu12"
,
"nvidia-cudnn-cu12"
,
"nvidia-cuda-cccl-cu12"
,
"nvidia-cuda-nvcc-cu12"
,
"nvidia-nvtx-cu12"
,
"nvidia-cuda-nvrtc-cu12"
,
],
install_requires
=
[
"jax"
,
"flax>=0.7.1"
],
tests_require
=
[
"numpy"
],
)
)
if
any
(
x
in
sys
.
argv
for
x
in
(
"."
,
"sdist"
,
"bdist_wheel"
)):
if
any
(
x
in
sys
.
argv
for
x
in
(
"."
,
"sdist"
,
"bdist_wheel"
)):
shutil
.
rmtree
(
common_headers_dir
)
shutil
.
rmtree
(
common_headers_dir
)
...
...
transformer_engine/jax/sharding.py
View file @
2b05e121
...
@@ -18,6 +18,7 @@ from jax.interpreters import pxla
...
@@ -18,6 +18,7 @@ from jax.interpreters import pxla
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax.sharding
import
PartitionSpec
from
jax.sharding
import
PartitionSpec
import
numpy
as
np
_PXLA_THREAD_RESOURCES
=
pxla
.
thread_resources
_PXLA_THREAD_RESOURCES
=
pxla
.
thread_resources
...
@@ -201,6 +202,31 @@ def get_mesh_axis_rank(axis: str, mesh=None):
...
@@ -201,6 +202,31 @@ def get_mesh_axis_rank(axis: str, mesh=None):
return
jax
.
lax
.
axis_index
(
axis_name
)
return
jax
.
lax
.
axis_index
(
axis_name
)
def get_mesh_axis_rank_host(axis, mesh) -> int:
    """
    Same as get_mesh_axis_rank(), but return a host value instead of a
    traced device value.

    Args:
        axis: Name of the mesh axis to query.
        mesh: The jax.sharding.Mesh whose device grid is inspected.

    Returns:
        The integer coordinate (rank) of this process's device along `axis`.

    Raises:
        ValueError: If `axis` is not one of the mesh axis names, or if this
            process's device is not part of the mesh.
    """
    if axis not in mesh.axis_names:
        raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}")
    axis_index = mesh.axis_names.index(axis)
    # mesh.devices is an ndarray of Device objects laid out on the mesh grid.
    devices = mesh.devices
    # Pick one device on this host. jax.local_devices() is guaranteed to return
    # devices attached to the current process, whereas indexing the *global*
    # jax.devices() list with jax.process_index() is only correct when every
    # process owns exactly one device.
    local_device = jax.local_devices()[0]
    # Find index of local_device in mesh.devices
    coords = np.argwhere(devices == local_device)
    if coords.size == 0:
        raise ValueError(f"Local device {local_device} not found in mesh.devices.")
    coords = tuple(coords[0])  # Coordinates in the mesh array
    # Get the mesh rank along the specified axis
    rank = coords[axis_index]
    return int(rank)
@
dataclass
@
dataclass
class
MeshResource
:
class
MeshResource
:
"""A data container for managing mesh resources in distributed training.
"""A data container for managing mesh resources in distributed training.
...
...
transformer_engine/pytorch/attention/dot_product_attention/backends.py
View file @
2b05e121
...
@@ -217,7 +217,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
...
@@ -217,7 +217,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
if
"padding"
in
attn_mask_type
and
attention_mask
is
None
:
if
"padding"
in
attn_mask_type
and
attention_mask
is
None
:
attention_mask
=
dpa_utils
.
get_padding_mask
(
attention_mask
=
dpa_utils
.
get_padding_mask
(
batch_size
,
cu_seqlens_q
,
cu_seqlens_kv
,
max_seqlen_q
,
max_seqlen_kv
batch_size
,
cu_seqlens_q
,
cu_seqlens_kv
,
max_seqlen_q
,
max_seqlen_kv
,
self
.
attention_type
,
)
)
attn_mask_type
,
attention_mask
,
actual_seqlens_q
,
actual_seqlens_kv
=
(
attn_mask_type
,
attention_mask
,
actual_seqlens_q
,
actual_seqlens_kv
=
(
dpa_utils
.
get_full_mask
(
dpa_utils
.
get_full_mask
(
...
...
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
View file @
2b05e121
...
@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
):
):
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring
nvtx_range_push
(
"transformer_engine.AttnFuncWithCPAndKVP2P.forward"
)
nvtx_range_push
(
"transformer_engine.AttnFuncWithCPAndKVP2P.forward"
)
enable_mla
=
k
.
shape
[
-
1
]
!=
v
.
shape
[
-
1
]
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
...
@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_half
,
cu_seqlens_kv_half
=
None
,
None
cu_seqlens_q_half
,
cu_seqlens_kv_half
=
None
,
None
if
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
qkv_format
in
[
"bshd"
,
"sbhd"
]:
seq_dim
=
qkv_format
.
index
(
"s"
)
seq_dim
=
qkv_format
.
index
(
"s"
)
if
enable_mla
:
qkv_layout
=
qkv_format
+
"_"
+
qkv_format
+
"_"
+
qkv_format
else
:
qkv_layout
=
qkv_format
+
"_"
+
qkv_format
[:
-
2
]
+
"2"
+
qkv_format
[
-
2
:]
qkv_layout
=
qkv_format
+
"_"
+
qkv_format
[:
-
2
]
+
"2"
+
qkv_format
[
-
2
:]
cu_seqlens_q_padded
,
cu_seqlens_kv_padded
=
None
,
None
cu_seqlens_q_padded
,
cu_seqlens_kv_padded
=
None
,
None
if
use_fused_attention
:
if
use_fused_attention
:
...
@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fwd_results_correction_done
=
torch
.
cuda
.
Event
()
fwd_results_correction_done
=
torch
.
cuda
.
Event
()
p2p_comm_buffers
=
[
None
for
_
in
range
(
cp_size
)]
p2p_comm_buffers
=
[
None
for
_
in
range
(
cp_size
)]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
enable_mla
:
# If MLA, the shape of k and v does not match, so we flatten them
# and split them after receiving them.
k_shape
=
k
.
shape
k_numel
=
k
.
numel
()
v_shape
=
v
.
shape
p2p_comm_buffers
[
0
]
=
torch
.
cat
((
k
.
view
(
-
1
),
v
.
view
(
-
1
)),
dim
=-
1
)
elif
qkv_format
in
[
"bshd"
,
"sbhd"
]:
p2p_comm_buffers
[
0
]
=
torch
.
cat
((
k
.
unsqueeze
(
-
3
),
v
.
unsqueeze
(
-
3
)),
dim
=-
3
)
p2p_comm_buffers
[
0
]
=
torch
.
cat
((
k
.
unsqueeze
(
-
3
),
v
.
unsqueeze
(
-
3
)),
dim
=-
3
)
else
:
else
:
# qkv_format == "thd"
p2p_comm_buffers
[
0
]
=
torch
.
cat
((
k
.
unsqueeze
(
0
),
v
.
unsqueeze
(
0
)),
dim
=
0
)
p2p_comm_buffers
[
0
]
=
torch
.
cat
((
k
.
unsqueeze
(
0
),
v
.
unsqueeze
(
0
)),
dim
=
0
)
send_recv_reqs
=
[[],
[]]
send_recv_reqs
=
[[],
[]]
...
@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else
:
else
:
# KV exchange is in BF16/FP16, cast received KV in each step
# KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs
[
i
%
2
]
=
QKV_quantizer
(
p2p_comm_buffers
[
i
]).
_data
kv_inputs
[
i
%
2
]
=
QKV_quantizer
(
p2p_comm_buffers
[
i
]).
_data
if
enable_mla
:
# If MLA, k and v are flattened, so split them after receiving.
k_part
=
kv_inputs
[
i
%
2
][:
k_numel
].
view
(
*
k_shape
)
v_part
=
kv_inputs
[
i
%
2
][
k_numel
:].
view
(
*
v_shape
)
if
causal
:
if
causal
:
if
i
==
0
:
if
i
==
0
:
if
pad_between_seqs
:
if
pad_between_seqs
:
...
@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
qkv_format
==
"bshd"
:
if
qkv_format
==
"bshd"
:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs
[
i
%
2
]
=
q
.
view
(
q
.
shape
[
0
],
-
1
,
*
q
.
shape
[
-
2
:])
q_inputs
[
i
%
2
]
=
q
.
view
(
q
.
shape
[
0
],
-
1
,
*
q
.
shape
[
-
2
:])
if
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part
=
k_part
.
view
(
k_part
.
shape
[
0
],
-
1
,
*
k_part
.
shape
[
-
2
:])
v_part
=
v_part
.
view
(
v_part
.
shape
[
0
],
-
1
,
*
v_part
.
shape
[
-
2
:])
else
:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
k
.
shape
[
0
],
-
1
,
2
,
*
k
.
shape
[
-
2
:]
k
.
shape
[
0
],
-
1
,
2
,
*
k
.
shape
[
-
2
:]
...
@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif
qkv_format
==
"sbhd"
:
elif
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs
[
i
%
2
]
=
q
.
view
(
-
1
,
*
q
.
shape
[
-
3
:])
q_inputs
[
i
%
2
]
=
q
.
view
(
-
1
,
*
q
.
shape
[
-
3
:])
if
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part
=
k_part
.
view
(
-
1
,
*
k_part
.
shape
[
2
:])
v_part
=
v_part
.
view
(
-
1
,
*
v_part
.
shape
[
2
:])
else
:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
-
1
,
k
.
shape
[
2
],
2
,
*
k
.
shape
[
-
2
:]
-
1
,
k
.
shape
[
2
],
2
,
*
k
.
shape
[
-
2
:]
...
@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).
contiguous
()
).
contiguous
()
q_part
=
q_inputs
[
i
%
2
]
q_part
=
q_inputs
[
i
%
2
]
if
not
enable_mla
:
# If MHA, then split the KV into k_part and v_part.
# Otherwise (MLA), k_part and v_part have already been split.
k_part
=
(
k_part
=
(
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
...
@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_kv
=
max_seqlen_kv
,
max_seqlen_kv
=
max_seqlen_kv
,
)
)
# Need to add MLA support once Flash Attention supports MLA
fa_outputs
=
flash_attn_fwd
(
fa_outputs
=
flash_attn_fwd
(
q_inputs
[
i
%
2
],
q_inputs
[
i
%
2
],
(
(
...
@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
qkv_format
==
"bshd"
:
if
qkv_format
==
"bshd"
:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs
[
i
%
2
]
=
q
.
view
(
q
.
shape
[
0
],
-
1
,
*
q
.
shape
[
-
2
:])
q_inputs
[
i
%
2
]
=
q
.
view
(
q
.
shape
[
0
],
-
1
,
*
q
.
shape
[
-
2
:])
if
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part
=
k_part
[:,
0
,
...]
v_part
=
v_part
[:,
0
,
...]
else
:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
][:,
0
,
...]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
][:,
0
,
...]
elif
qkv_format
==
"sbhd"
:
elif
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs
[
i
%
2
]
=
q
.
view
(
-
1
,
*
q
.
shape
[
-
3
:])
q_inputs
[
i
%
2
]
=
q
.
view
(
-
1
,
*
q
.
shape
[
-
3
:])
if
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part
=
k_part
[
0
]
v_part
=
v_part
[
0
]
else
:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
][
0
]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
][
0
]
elif
qkv_format
==
"thd"
:
elif
qkv_format
==
"thd"
:
q_inputs
[
i
%
2
]
=
q
q_inputs
[
i
%
2
]
=
q
if
enable_mla
:
# [t, np, hn] -> [t/2, np, hn]
k_part
=
tex
.
thd_read_half_tensor
(
k_part
,
cu_seqlens_kv_padded
,
0
)
v_part
=
tex
.
thd_read_half_tensor
(
v_part
,
cu_seqlens_kv_padded
,
0
)
else
:
# [2, t, np, hn] -> [2, t/2, np, hn]
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs
[
i
%
2
]
=
tex
.
thd_read_half_tensor
(
kv_inputs
[
i
%
2
]
=
tex
.
thd_read_half_tensor
(
kv_inputs
[
i
%
2
],
cu_seqlens_kv_padded
,
0
kv_inputs
[
i
%
2
],
cu_seqlens_kv_padded
,
0
)
)
if
use_fused_attention
:
if
use_fused_attention
:
if
enable_mla
:
k_part
=
k_part
.
contiguous
()
v_part
=
v_part
.
contiguous
()
else
:
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
contiguous
()
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
contiguous
()
if
attn_bias
is
not
None
:
if
attn_bias
is
not
None
:
idx
=
(
rank
-
i
)
%
cp_size
idx
=
(
rank
-
i
)
%
cp_size
attn_bias_inputs
[
i
%
2
]
=
attn_bias
[...,
idx
,
:].
contiguous
()
attn_bias_inputs
[
i
%
2
]
=
attn_bias
[...,
idx
,
:].
contiguous
()
q_part
=
q_inputs
[
i
%
2
]
q_part
=
q_inputs
[
i
%
2
]
if
not
enable_mla
:
k_part
=
(
k_part
=
(
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
...
@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif
fa_utils
.
v2_7_0_plus
:
elif
fa_utils
.
v2_7_0_plus
:
fa_forward_kwargs
[
"window_size_left"
]
=
-
1
fa_forward_kwargs
[
"window_size_left"
]
=
-
1
fa_forward_kwargs
[
"window_size_right"
]
=
-
1
fa_forward_kwargs
[
"window_size_right"
]
=
-
1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs
=
flash_attn_fwd
(
fa_outputs
=
flash_attn_fwd
(
q_inputs
[
i
%
2
],
q_inputs
[
i
%
2
],
(
(
...
@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
qkv_format
==
"bshd"
:
if
qkv_format
==
"bshd"
:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs
[
i
%
2
]
=
q
[:,
1
,
...]
q_inputs
[
i
%
2
]
=
q
[:,
1
,
...]
if
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part
=
k_part
.
view
(
k_part
.
shape
[
0
],
-
1
,
*
k_part
.
shape
[
-
2
:])
v_part
=
v_part
.
view
(
v_part
.
shape
[
0
],
-
1
,
*
v_part
.
shape
[
-
2
:])
else
:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
k
.
shape
[
0
],
-
1
,
2
,
*
k
.
shape
[
-
2
:]
k
.
shape
[
0
],
-
1
,
2
,
*
k
.
shape
[
-
2
:]
...
@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif
qkv_format
==
"sbhd"
:
elif
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs
[
i
%
2
]
=
q
[
1
]
q_inputs
[
i
%
2
]
=
q
[
1
]
if
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part
=
k_part
.
view
(
-
1
,
*
k_part
.
shape
[
2
:])
v_part
=
v_part
.
view
(
-
1
,
*
v_part
.
shape
[
2
:])
else
:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
kv_inputs
[
i
%
2
]
=
kv_inputs
[
i
%
2
].
view
(
-
1
,
k
.
shape
[
2
],
2
,
*
k
.
shape
[
-
2
:]
-
1
,
k
.
shape
[
2
],
2
,
*
k
.
shape
[
-
2
:]
...
@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).
contiguous
()
).
contiguous
()
q_part
=
q_inputs
[
i
%
2
]
q_part
=
q_inputs
[
i
%
2
]
if
not
enable_mla
:
k_part
=
(
k_part
=
(
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
...
@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif
fa_utils
.
v2_7_0_plus
:
elif
fa_utils
.
v2_7_0_plus
:
fa_forward_kwargs
[
"window_size_left"
]
=
-
1
fa_forward_kwargs
[
"window_size_left"
]
=
-
1
fa_forward_kwargs
[
"window_size_right"
]
=
-
1
fa_forward_kwargs
[
"window_size_right"
]
=
-
1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs
=
flash_attn_fwd
(
fa_outputs
=
flash_attn_fwd
(
q_inputs
[
i
%
2
],
q_inputs
[
i
%
2
],
(
(
...
@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).
contiguous
()
).
contiguous
()
q_part
=
q
q_part
=
q
if
not
enable_mla
:
k_part
=
(
k_part
=
(
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
kv_inputs
[
i
%
2
][...,
0
,
:,
:]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
if
qkv_format
in
[
"bshd"
,
"sbhd"
]
...
@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_kv
=
max_seqlen_kv
,
max_seqlen_kv
=
max_seqlen_kv
,
)
)
# Need to add MLA support once Flash Attention supports MLA
fa_outputs
=
flash_attn_fwd
(
fa_outputs
=
flash_attn_fwd
(
q
,
q
,
(
(
...
@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
i
==
1
:
if
i
==
1
:
softmax_lse
=
torch
.
clone
(
softmax_lse_per_step
[
0
])
softmax_lse
=
torch
.
clone
(
softmax_lse_per_step
[
0
])
if
qkv_format
==
"thd"
:
if
qkv_format
==
"thd"
:
out
=
torch
.
zeros_like
(
q
if
not
fp8
else
out_per_step
[
0
]).
view
(
q
.
shape
)
if
enable_mla
:
out
=
torch
.
zeros_like
(
v
if
not
fp8
else
out_per_step
[
0
]).
view
(
v_shape
)
else
:
# MHA or GQA
out
=
torch
.
zeros_like
(
q
if
not
fp8
else
out_per_step
[
0
]).
view
(
q
.
shape
)
elif
(
i
-
1
)
<=
rank
or
not
causal
:
elif
(
i
-
1
)
<=
rank
or
not
causal
:
flash_attn_fwd_softmax_lse_correction
(
flash_attn_fwd_softmax_lse_correction
(
softmax_lse
,
softmax_lse_per_step
[
i
-
1
]
softmax_lse
,
softmax_lse_per_step
[
i
-
1
]
...
@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step
[
0
],
softmax_lse_per_step
[
0
],
seq_dim
,
seq_dim
,
)
)
if
enable_mla
:
out
=
out
.
view
(
v_shape
)
else
:
out
=
out
.
view
(
q
.
shape
)
out
=
out
.
view
(
q
.
shape
)
else
:
else
:
flash_attn_fwd_out_correction
(
flash_attn_fwd_out_correction
(
...
@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx
.
is_output_fp8
=
is_output_fp8
ctx
.
is_output_fp8
=
is_output_fp8
ctx
.
use_flash_attn_3
=
use_flash_attn_3
ctx
.
use_flash_attn_3
=
use_flash_attn_3
ctx
.
enable_mla
=
enable_mla
if
enable_mla
:
ctx
.
k_numel
=
k_numel
ctx
.
k_shape
=
k_shape
ctx
.
v_shape
=
v_shape
ctx
.
qkv_dtype
=
qkv_dtype
ctx
.
qkv_dtype
=
qkv_dtype
ctx
.
dQKV_quantizer
=
dQKV_quantizer
ctx
.
dQKV_quantizer
=
dQKV_quantizer
ctx
.
dQKV_CP_quantizer
=
dQKV_CP_quantizer
ctx
.
dQKV_CP_quantizer
=
dQKV_CP_quantizer
...
@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
seq_dim
=
None
seq_dim
=
None
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
seq_dim
=
ctx
.
qkv_format
.
index
(
"s"
)
seq_dim
=
ctx
.
qkv_format
.
index
(
"s"
)
if
ctx
.
enable_mla
:
qkv_layout
=
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
else
:
qkv_layout
=
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
[:
-
2
]
+
"2"
+
ctx
.
qkv_format
[
-
2
:]
qkv_layout
=
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
[:
-
2
]
+
"2"
+
ctx
.
qkv_format
[
-
2
:]
else
:
else
:
qkv_layout
=
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
qkv_layout
=
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
+
"_"
+
ctx
.
qkv_format
...
@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
)
)
dout
=
dout
.
dequantize
(
dtype
=
dout_dtype
)
dout
=
dout
.
dequantize
(
dtype
=
dout_dtype
)
if
ctx
.
enable_mla
:
out
=
out
.
view
(
*
ctx
.
v_shape
)
dout
=
dout
.
view
(
*
ctx
.
v_shape
)
else
:
# MHA or GQA
out
=
out
.
view
(
*
q
.
shape
)
out
=
out
.
view
(
*
q
.
shape
)
dout
=
dout
.
view
(
*
q
.
shape
)
dout
=
dout
.
view
(
*
q
.
shape
)
send_recv_reqs
=
[]
send_recv_reqs
=
[]
...
@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv
=
p2p_comm_buffers
[
i
%
2
][
0
]
kv
=
p2p_comm_buffers
[
i
%
2
][
0
]
q_
,
kv_
,
out_
,
dout_
=
None
,
None
,
None
,
None
q_
,
kv_
,
out_
,
dout_
=
None
,
None
,
None
,
None
dq_
,
dk_
,
dv_
=
None
,
None
,
None
dq_
,
dk_
,
dv_
=
None
,
None
,
None
if
ctx
.
enable_mla
:
k_part
=
kv
[:
ctx
.
k_numel
].
view
(
*
ctx
.
k_shape
)
v_part
=
kv
[
ctx
.
k_numel
:].
view
(
*
ctx
.
v_shape
)
# In reversed order of fwd
# In reversed order of fwd
if
causal
:
if
causal
:
if
i
==
(
cp_size
-
1
):
if
i
==
(
cp_size
-
1
):
...
@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_
,
out_
,
dout_
=
[
q_
,
out_
,
dout_
=
[
x
.
view
(
x
.
shape
[
0
],
-
1
,
*
x
.
shape
[
-
2
:])
for
x
in
[
q
,
out
,
dout
]
x
.
view
(
x
.
shape
[
0
],
-
1
,
*
x
.
shape
[
-
2
:])
for
x
in
[
q
,
out
,
dout
]
]
]
if
ctx
.
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part
=
k_part
.
view
(
k_part
.
shape
[
0
],
-
1
,
*
k_part
.
shape
[
-
2
:])
v_part
=
v_part
.
view
(
v_part
.
shape
[
0
],
-
1
,
*
v_part
.
shape
[
-
2
:])
else
:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_
=
kv
.
view
(
kv
.
shape
[
0
],
-
1
,
*
kv
.
shape
[
-
3
:])
kv_
=
kv
.
view
(
kv
.
shape
[
0
],
-
1
,
*
kv
.
shape
[
-
3
:])
elif
ctx
.
qkv_format
==
"sbhd"
:
elif
ctx
.
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_
,
out_
,
dout_
=
[
x
.
view
(
-
1
,
*
x
.
shape
[
-
3
:])
for
x
in
[
q
,
out
,
dout
]]
q_
,
out_
,
dout_
=
[
x
.
view
(
-
1
,
*
x
.
shape
[
-
3
:])
for
x
in
[
q
,
out
,
dout
]]
if
ctx
.
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part
=
k_part
.
view
(
-
1
,
*
k_part
.
shape
[
-
3
:])
v_part
=
v_part
.
view
(
-
1
,
*
v_part
.
shape
[
-
3
:])
else
:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_
=
kv
.
view
(
-
1
,
*
kv
.
shape
[
-
4
:])
kv_
=
kv
.
view
(
-
1
,
*
kv
.
shape
[
-
4
:])
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
...
@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
attn_dbias
is
not
None
:
if
attn_dbias
is
not
None
:
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
q_part
=
q_
q_part
=
q_
k_part
=
kv_
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
0
]
if
not
ctx
.
enable_mla
:
v_part
=
kv_
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
1
]
k_part
=
(
kv_
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
0
]
)
v_part
=
(
kv_
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
1
]
)
out_part
=
out_
out_part
=
out_
dout_part
=
dout_
dout_part
=
dout_
...
@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs
[
"window_size_right"
]
=
0
fa_backward_kwargs
[
"window_size_right"
]
=
0
if
not
ctx
.
use_flash_attn_3
:
if
not
ctx
.
use_flash_attn_3
:
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd
(
flash_attn_bwd
(
dout_
,
dout_
,
q_
,
q_
,
...
@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_
,
out_
,
dout_
=
[
q_
,
out_
,
dout_
=
[
x
.
view
(
x
.
shape
[
0
],
-
1
,
*
x
.
shape
[
-
2
:])
for
x
in
[
q
,
out
,
dout
]
x
.
view
(
x
.
shape
[
0
],
-
1
,
*
x
.
shape
[
-
2
:])
for
x
in
[
q
,
out
,
dout
]
]
]
if
ctx
.
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part
=
k_part
[:,
0
]
v_part
=
v_part
[:,
0
]
else
:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_
=
kv
[:,
0
]
kv_
=
kv
[:,
0
]
elif
ctx
.
qkv_format
==
"sbhd"
:
elif
ctx
.
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_
,
out_
,
dout_
=
[
x
.
view
(
-
1
,
*
x
.
shape
[
-
3
:])
for
x
in
[
q
,
out
,
dout
]]
q_
,
out_
,
dout_
=
[
x
.
view
(
-
1
,
*
x
.
shape
[
-
3
:])
for
x
in
[
q
,
out
,
dout
]]
if
ctx
.
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part
=
k_part
[
0
]
v_part
=
v_part
[
0
]
else
:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_
=
kv
[
0
]
kv_
=
kv
[
0
]
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
q_
,
out_
,
dout_
=
q
,
out
,
dout
q_
,
out_
,
dout_
=
q
,
out
,
dout
if
ctx
.
enable_mla
:
# [t, np, hn] -> [t/2, np, hn]
k_part
=
tex
.
thd_read_half_tensor
(
k_part
,
cu_seqlens_kv_padded
,
0
)
v_part
=
tex
.
thd_read_half_tensor
(
v_part
,
cu_seqlens_kv_padded
,
0
)
else
:
# [2, t, np, hn] -> [2, t/2, np, hn]
# [2, t, np, hn] -> [2, t/2, np, hn]
kv_
=
tex
.
thd_read_half_tensor
(
kv
,
cu_seqlens_kv_padded
,
0
)
kv_
=
tex
.
thd_read_half_tensor
(
kv
,
cu_seqlens_kv_padded
,
0
)
if
ctx
.
use_fused_attention
:
if
ctx
.
use_fused_attention
:
if
ctx
.
enable_mla
:
k_part
=
k_part
.
contiguous
()
v_part
=
v_part
.
contiguous
()
else
:
kv_
=
kv_
.
contiguous
()
kv_
=
kv_
.
contiguous
()
if
ctx
.
fp8
:
if
ctx
.
fp8
:
aux_ctx_tensors
=
[
aux_ctx_tensors
=
[
...
@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
attn_dbias
is
not
None
:
if
attn_dbias
is
not
None
:
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
q_part
=
q_
q_part
=
q_
k_part
=
kv_
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
0
]
if
not
ctx
.
enable_mla
:
v_part
=
kv_
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
1
]
k_part
=
(
kv_
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
0
]
)
v_part
=
(
kv_
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
1
]
)
out_part
=
out_
out_part
=
out_
dout_part
=
dout_
dout_part
=
dout_
...
@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs
[
"window_size_right"
]
=
-
1
fa_backward_kwargs
[
"window_size_right"
]
=
-
1
if
not
ctx
.
use_flash_attn_3
:
if
not
ctx
.
use_flash_attn_3
:
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd
(
flash_attn_bwd
(
dout_
,
dout_
,
q_
,
q_
,
...
@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
ctx
.
qkv_format
==
"bshd"
:
if
ctx
.
qkv_format
==
"bshd"
:
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_
,
out_
,
dout_
=
q
[:,
1
],
out
[:,
1
],
dout
[:,
1
]
q_
,
out_
,
dout_
=
q
[:,
1
],
out
[:,
1
],
dout
[:,
1
]
if
ctx
.
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part
=
k_part
.
view
(
k_part
.
shape
[
0
],
-
1
,
*
k_part
.
shape
[
-
2
:])
v_part
=
v_part
.
view
(
v_part
.
shape
[
0
],
-
1
,
*
v_part
.
shape
[
-
2
:])
else
:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_
=
kv
.
view
(
kv
.
shape
[
0
],
-
1
,
*
kv
.
shape
[
-
3
:])
kv_
=
kv
.
view
(
kv
.
shape
[
0
],
-
1
,
*
kv
.
shape
[
-
3
:])
elif
ctx
.
qkv_format
==
"sbhd"
:
elif
ctx
.
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_
,
out_
,
dout_
=
q
[
1
],
out
[
1
],
dout
[
1
]
q_
,
out_
,
dout_
=
q
[
1
],
out
[
1
],
dout
[
1
]
if
ctx
.
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part
=
k_part
.
view
(
-
1
,
*
k_part
.
shape
[
-
3
:])
v_part
=
v_part
.
view
(
-
1
,
*
v_part
.
shape
[
-
3
:])
else
:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_
=
kv
.
view
(
-
1
,
*
kv
.
shape
[
-
4
:])
kv_
=
kv
.
view
(
-
1
,
*
kv
.
shape
[
-
4
:])
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
...
@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
q_part
=
q_
q_part
=
q_
k_part
=
kv_
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
0
]
if
not
ctx
.
enable_mla
:
v_part
=
kv_
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
1
]
k_part
=
(
kv_
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
0
]
)
v_part
=
(
kv_
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv_
[
1
]
)
out_part
=
out_
out_part
=
out_
dout_part
=
dout_
dout_part
=
dout_
...
@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs
[
"window_size_right"
]
=
-
1
fa_backward_kwargs
[
"window_size_right"
]
=
-
1
if
not
ctx
.
use_flash_attn_3
:
if
not
ctx
.
use_flash_attn_3
:
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd
(
flash_attn_bwd
(
dout_
,
dout_
,
q_
,
q_
,
...
@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
attn_dbias
is
not
None
:
if
attn_dbias
is
not
None
:
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
aux_ctx_tensors
+=
[
attn_biases
[
cp_size
-
i
-
1
]]
q_part
=
q
q_part
=
q
if
not
ctx
.
enable_mla
:
k_part
=
kv
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv
[
0
]
k_part
=
kv
[...,
0
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv
[
0
]
v_part
=
kv
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv
[
1
]
v_part
=
kv
[...,
1
,
:,
:]
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]
else
kv
[
1
]
out_part
=
out
out_part
=
out
...
@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs
[
"window_size_right"
]
=
-
1
fa_backward_kwargs
[
"window_size_right"
]
=
-
1
if
not
ctx
.
use_flash_attn_3
:
if
not
ctx
.
use_flash_attn_3
:
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
fa_backward_kwargs
[
"rng_state"
]
=
rng_states
[
cp_size
-
i
-
1
]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd
(
flash_attn_bwd
(
dout
,
dout
,
q
,
q
,
...
@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else
:
else
:
dkv
=
p2p_comm_buffers
[(
i
+
1
)
%
2
][
1
]
dkv
=
p2p_comm_buffers
[(
i
+
1
)
%
2
][
1
]
if
ctx
.
use_fused_attention
:
if
ctx
.
use_fused_attention
:
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
ctx
.
enable_mla
:
dkv_
=
None
elif
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
dkv_
=
combine_tensors
([
dk_
,
dv_
],
-
2
)
dkv_
=
combine_tensors
([
dk_
,
dv_
],
-
2
)
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
dkv_
=
torch
.
cat
(
dkv_
=
torch
.
cat
(
(
dk_
.
unsqueeze
(
0
),
dv_
.
unsqueeze
(
0
)),
dim
=
0
(
dk_
.
unsqueeze
(
0
),
dv_
.
unsqueeze
(
0
)),
dim
=
0
)
# pylint: disable=used-before-assignment
)
# pylint: disable=used-before-assignment
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
not
ctx
.
enable_mla
and
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
# dkv is a buffer, so we do not need to transpose it, but only need to reshape it.
dkv
=
dkv
.
view
(
2
,
*
dkv
.
shape
[
0
:
-
3
],
*
dkv
.
shape
[
-
2
:])
dkv
=
dkv
.
view
(
2
,
*
dkv
.
shape
[
0
:
-
3
],
*
dkv
.
shape
[
-
2
:])
dkv_
=
dkv_
.
movedim
(
-
3
,
0
)
dkv_
=
dkv_
.
movedim
(
-
3
,
0
)
if
causal
and
(
i
<
(
cp_size
-
rank
-
1
)
or
i
==
(
cp_size
-
1
)):
if
causal
and
(
i
<
(
cp_size
-
rank
-
1
)
or
i
==
(
cp_size
-
1
)):
...
@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
# [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv_
=
dkv_
.
view
(
*
dkv
.
shape
)
dkv_
=
dkv_
.
view
(
*
dkv
.
shape
)
if
ctx
.
enable_mla
:
# [b, 2, sk//2, np, hn] or
# [2, sk//2, b, np, hn]
dk
=
dkv
[:
ctx
.
k_numel
].
view
(
*
ctx
.
k_shape
)
dv
=
dkv
[
ctx
.
k_numel
:].
view
(
*
ctx
.
v_shape
)
if
causal
and
(
i
<
(
cp_size
-
rank
-
1
)
or
i
==
(
cp_size
-
1
)):
dk_
=
dk_
.
view
(
*
ctx
.
k_shape
)
dv_
=
dv_
.
view
(
*
ctx
.
v_shape
)
if
ctx
.
fp8
:
# enable_mla and fp8
if
causal
and
i
>=
(
cp_size
-
rank
-
1
)
and
i
!=
(
cp_size
-
1
):
if
ctx
.
qkv_format
==
"bshd"
:
dk
[:,
0
,
...].
copy_
(
dk_
)
dk
[:,
1
,
...].
fill_
(
0
)
dv
[:,
0
,
...].
copy_
(
dv_
)
dv
[:,
1
,
...].
fill_
(
0
)
elif
ctx
.
qkv_format
==
"sbhd"
:
dk
[
0
].
copy_
(
dk_
)
dk
[
1
].
fill_
(
0
)
dv
[
0
].
copy_
(
dv_
)
dv
[
1
].
fill_
(
0
)
else
:
dk
.
copy_
(
dk_
)
dv
.
copy_
(
dv_
)
elif
causal
:
# enable_mla and not fp8 and causal
if
i
==
(
cp_size
-
1
):
if
rank
==
0
:
if
ctx
.
qkv_format
==
"bshd"
:
dk
[:,
0
,
...].
add_
(
dk_
[:,
0
,
...])
dk
[:,
1
,
...].
copy_
(
dk_
[:,
1
,
...])
dv
[:,
0
,
...].
add_
(
dv_
[:,
0
,
...])
dv
[:,
1
,
...].
copy_
(
dv_
[:,
1
,
...])
elif
ctx
.
qkv_format
==
"sbhd"
:
dk
[
0
,
...].
add_
(
dk_
[
0
,
...])
dk
[
1
,
...].
copy_
(
dk_
[
1
,
...])
dv
[
0
,
...].
add_
(
dv_
[
0
,
...])
dv
[
1
,
...].
copy_
(
dv_
[
1
,
...])
elif
ctx
.
qkv_format
==
"thd"
:
tex
.
thd_grad_correction
(
dk
,
dk_
,
cu_seqlens_kv_padded
,
"add"
,
"copy"
)
tex
.
thd_grad_correction
(
dv
,
dv_
,
cu_seqlens_kv_padded
,
"add"
,
"copy"
)
else
:
dk
.
add_
(
dk_
)
dv
.
add_
(
dv_
)
elif
i
>=
(
cp_size
-
rank
-
1
):
if
i
==
0
and
rank
==
(
cp_size
-
1
):
if
ctx
.
qkv_format
==
"bshd"
:
dk
[:,
0
,
...].
copy_
(
dk_
)
dv
[:,
0
,
...].
copy_
(
dv_
)
elif
ctx
.
qkv_format
==
"sbhd"
:
dk
[
0
,
...].
copy_
(
dk_
)
dv
[
0
,
...].
copy_
(
dv_
)
elif
ctx
.
qkv_format
==
"thd"
:
tex
.
thd_grad_correction
(
dk
,
dk_
,
cu_seqlens_kv_padded
,
"copy"
,
"none"
)
tex
.
thd_grad_correction
(
dv
,
dv_
,
cu_seqlens_kv_padded
,
"copy"
,
"none"
)
else
:
if
ctx
.
qkv_format
==
"bshd"
:
dk
[:,
0
,
...].
add_
(
dk_
)
dv
[:,
0
,
...].
add_
(
dv_
)
elif
ctx
.
qkv_format
==
"sbhd"
:
dk
[
0
,
...].
add_
(
dk_
)
dv
[
0
,
...].
add_
(
dv_
)
elif
ctx
.
qkv_format
==
"thd"
:
tex
.
thd_grad_correction
(
dk
,
dk_
,
cu_seqlens_kv_padded
,
"add"
,
"none"
)
tex
.
thd_grad_correction
(
dv
,
dv_
,
cu_seqlens_kv_padded
,
"add"
,
"none"
)
elif
i
>
0
:
dk
.
add_
(
dk_
)
dv
.
add_
(
dv_
)
else
:
# i == 0
dk
.
copy_
(
dk_
)
dv
.
copy_
(
dv_
)
else
:
# enable_mla and not fp8 and not causal
if
i
==
0
:
dk
.
copy_
(
dk_
)
dv
.
copy_
(
dv_
)
else
:
# i > 0
dk
.
add_
(
dk_
)
dv
.
add_
(
dv_
)
else
:
if
ctx
.
fp8
:
if
ctx
.
fp8
:
# fp8
if
causal
and
i
>=
(
cp_size
-
rank
-
1
)
and
i
!=
(
cp_size
-
1
):
if
causal
and
i
>=
(
cp_size
-
rank
-
1
)
and
i
!=
(
cp_size
-
1
):
if
ctx
.
qkv_format
==
"bshd"
:
if
ctx
.
qkv_format
==
"bshd"
:
dkv
[:,
:,
0
,
...].
copy_
(
dkv_
)
dkv
[:,
:,
0
,
...].
copy_
(
dkv_
)
...
@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else
:
else
:
dkv
.
copy_
(
dkv_
)
dkv
.
copy_
(
dkv_
)
elif
causal
:
elif
causal
:
# not fp8 and causal
if
i
==
(
cp_size
-
1
):
if
i
==
(
cp_size
-
1
):
if
rank
==
0
:
if
rank
==
0
:
if
ctx
.
qkv_format
==
"bshd"
:
if
ctx
.
qkv_format
==
"bshd"
:
...
@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv
[:,
0
,
...].
add_
(
dkv_
[:,
0
,
...])
dkv
[:,
0
,
...].
add_
(
dkv_
[:,
0
,
...])
dkv
[:,
1
,
...].
copy_
(
dkv_
[:,
1
,
...])
dkv
[:,
1
,
...].
copy_
(
dkv_
[:,
1
,
...])
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
tex
.
thd_grad_correction
(
dkv
,
dkv_
,
cu_seqlens_kv_padded
,
"add"
,
"copy"
)
tex
.
thd_grad_correction
(
dkv
,
dkv_
,
cu_seqlens_kv_padded
,
"add"
,
"copy"
)
else
:
else
:
dkv
.
add_
(
dkv_
)
dkv
.
add_
(
dkv_
)
elif
i
>=
(
cp_size
-
rank
-
1
):
elif
i
>=
(
cp_size
-
rank
-
1
):
...
@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif
ctx
.
qkv_format
==
"sbhd"
:
elif
ctx
.
qkv_format
==
"sbhd"
:
dkv
[:,
0
,
...].
copy_
(
dkv_
)
dkv
[:,
0
,
...].
copy_
(
dkv_
)
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
tex
.
thd_grad_correction
(
dkv
,
dkv_
,
cu_seqlens_kv_padded
,
"copy"
,
"none"
)
tex
.
thd_grad_correction
(
dkv
,
dkv_
,
cu_seqlens_kv_padded
,
"copy"
,
"none"
)
else
:
else
:
if
ctx
.
qkv_format
==
"bshd"
:
if
ctx
.
qkv_format
==
"bshd"
:
dkv
[:,
:,
0
,
...].
add_
(
dkv_
)
dkv
[:,
:,
0
,
...].
add_
(
dkv_
)
elif
ctx
.
qkv_format
==
"sbhd"
:
elif
ctx
.
qkv_format
==
"sbhd"
:
dkv
[:,
0
,
...].
add_
(
dkv_
)
dkv
[:,
0
,
...].
add_
(
dkv_
)
elif
ctx
.
qkv_format
==
"thd"
:
elif
ctx
.
qkv_format
==
"thd"
:
tex
.
thd_grad_correction
(
dkv
,
dkv_
,
cu_seqlens_kv_padded
,
"add"
,
"none"
)
tex
.
thd_grad_correction
(
dkv
,
dkv_
,
cu_seqlens_kv_padded
,
"add"
,
"none"
)
elif
i
>
0
:
elif
i
>
0
:
dkv
.
add_
(
dkv_
)
dkv
.
add_
(
dkv_
)
else
:
else
:
# i == 0
dkv
.
copy_
(
dkv_
)
dkv
.
copy_
(
dkv_
)
else
:
else
:
# not fp8 and not causal
if
i
==
0
:
if
i
==
0
:
dkv
.
copy_
(
dkv_
)
dkv
.
copy_
(
dkv_
)
else
:
else
:
# i > 0
dkv
.
add_
(
dkv_
)
dkv
.
add_
(
dkv_
)
if
ctx
.
fp8
and
ctx
.
use_fused_attention
:
if
ctx
.
fp8
and
ctx
.
use_fused_attention
:
amax_cp_bwd
=
amax_per_step
.
amax
(
dim
=
1
)
amax_cp_bwd
=
amax_per_step
.
amax
(
dim
=
1
)
ctx
.
dP_quantizer
.
amax
.
copy_
(
amax_cp_bwd
[
0
])
ctx
.
dP_quantizer
.
amax
.
copy_
(
amax_cp_bwd
[
0
])
ctx
.
dQKV_CP_quantizer
.
amax
.
copy_
(
amax_cp_bwd
[
1
])
ctx
.
dQKV_CP_quantizer
.
amax
.
copy_
(
amax_cp_bwd
[
1
])
dq
=
ctx
.
dQKV_CP_quantizer
.
create_tensor_from_data
(
dq_fp8
,
fake_dtype
=
torch
.
float32
,
internal
=
True
)
if
ctx
.
enable_mla
:
# [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
dk_fp8
=
dkv_fp8
[:
ctx
.
k_numel
].
view
(
cp_size
,
*
ctx
.
k_shape
)
dv_fp8
=
dkv_fp8
[
ctx
.
k_numel
:].
view
(
cp_size
,
*
ctx
.
v_shape
)
dk
=
ctx
.
dQKV_CP_quantizer
.
create_tensor_from_data
(
dk_fp8
,
fake_dtype
=
torch
.
float32
,
internal
=
True
)
dv
=
ctx
.
dQKV_CP_quantizer
.
create_tensor_from_data
(
dv_fp8
,
fake_dtype
=
torch
.
float32
,
internal
=
True
)
dq
,
dk
,
dv
=
[
x
.
dequantize
(
dtype
=
torch
.
float32
)
for
x
in
[
dq
,
dk
,
dv
]]
dq
,
dk
,
dv
=
[
x
.
sum
(
dim
=
0
).
to
(
dout_dtype
)
for
x
in
[
dq
,
dk
,
dv
]]
else
:
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8
=
dkv_fp8
.
view
(
cp_size
,
2
,
*
dkv_fp8
.
shape
[
1
:
-
3
],
*
dkv_fp8
.
shape
[
-
2
:])
dkv_fp8
=
dkv_fp8
.
view
(
cp_size
,
2
,
*
dkv_fp8
.
shape
[
1
:
-
3
],
*
dkv_fp8
.
shape
[
-
2
:])
dq
=
ctx
.
dQKV_CP_quantizer
.
create_tensor_from_data
(
dq_fp8
,
fake_dtype
=
torch
.
float32
,
internal
=
True
)
dkv
=
ctx
.
dQKV_CP_quantizer
.
create_tensor_from_data
(
dkv
=
ctx
.
dQKV_CP_quantizer
.
create_tensor_from_data
(
dkv_fp8
,
fake_dtype
=
torch
.
float32
,
internal
=
True
dkv_fp8
,
fake_dtype
=
torch
.
float32
,
internal
=
True
)
)
...
@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
ctx
.
qkv_format
==
"bshd"
:
if
ctx
.
qkv_format
==
"bshd"
:
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq
=
dq
.
view
(
dq
.
shape
[
0
],
-
1
,
*
dq
.
shape
[
-
2
:])
dq
=
dq
.
view
(
dq
.
shape
[
0
],
-
1
,
*
dq
.
shape
[
-
2
:])
if
ctx
.
enable_mla
:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
dk
=
dk
.
view
(
*
dk
.
shape
[
0
],
-
1
,
*
dk
.
shape
[
-
2
:])
dv
=
dv
.
view
(
*
dv
.
shape
[
0
],
-
1
,
*
dv
.
shape
[
-
2
:])
else
:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv
=
dkv
.
view
(
*
dkv
.
shape
[
0
:
2
],
-
1
,
*
dkv
.
shape
[
-
2
:])
dkv
=
dkv
.
view
(
*
dkv
.
shape
[
0
:
2
],
-
1
,
*
dkv
.
shape
[
-
2
:])
elif
ctx
.
qkv_format
==
"sbhd"
:
elif
ctx
.
qkv_format
==
"sbhd"
:
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
# [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq
=
dq
.
view
(
-
1
,
*
dq
.
shape
[
-
3
:])
dq
=
dq
.
view
(
-
1
,
*
dq
.
shape
[
-
3
:])
if
ctx
.
enable_mla
:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
dk
=
dk
.
view
(
-
1
,
*
dk
.
shape
[
-
3
:])
dv
=
dv
.
view
(
-
1
,
*
dv
.
shape
[
-
3
:])
else
:
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv
=
dkv
.
view
(
dkv
.
shape
[
0
],
-
1
,
*
dkv
.
shape
[
-
3
:])
dkv
=
dkv
.
view
(
dkv
.
shape
[
0
],
-
1
,
*
dkv
.
shape
[
-
3
:])
if
ctx
.
qkv_format
==
"thd"
and
not
ctx
.
use_fused_attention
:
if
ctx
.
qkv_format
==
"thd"
and
not
ctx
.
use_fused_attention
:
dq
[
cu_seqlens_q_padded
[
-
1
]
:].
fill_
(
0
)
dq
[
cu_seqlens_q_padded
[
-
1
]
:].
fill_
(
0
)
if
ctx
.
enable_mla
:
dk
[
cu_seqlens_kv_padded
[
-
1
]
:].
fill_
(
0
)
dv
[
cu_seqlens_kv_padded
[
-
1
]
:].
fill_
(
0
)
else
:
dkv
[:,
cu_seqlens_kv_padded
[
-
1
]
:].
fill_
(
0
)
dkv
[:,
cu_seqlens_kv_padded
[
-
1
]
:].
fill_
(
0
)
if
ctx
.
fp8
and
ctx
.
is_input_fp8
:
if
ctx
.
fp8
and
ctx
.
is_input_fp8
:
assert
torch
.
uint8
not
in
[
dq
.
dtype
,
dkv
.
dtype
]
assert
torch
.
uint8
not
in
[
dq
.
dtype
,
dkv
.
dtype
]
if
ctx
.
enable_mla
:
dq
,
dk
,
dv
=
[
ctx
.
dQKV_quantizer
(
x
).
_data
for
x
in
[
dq
,
dk
,
dv
]]
else
:
dq
,
dkv
=
[
ctx
.
dQKV_quantizer
(
x
).
_data
for
x
in
[
dq
,
dkv
]]
dq
,
dkv
=
[
ctx
.
dQKV_quantizer
(
x
).
_data
for
x
in
[
dq
,
dkv
]]
if
not
ctx
.
enable_mla
:
dk
,
dv
=
dkv
[
0
],
dkv
[
1
]
dk
,
dv
=
dkv
[
0
],
dkv
[
1
]
if
cp_size_a2a
>
1
:
if
cp_size_a2a
>
1
:
...
@@ -3484,7 +3776,64 @@ def attn_forward_func_with_cp(
...
@@ -3484,7 +3776,64 @@ def attn_forward_func_with_cp(
use_flash_attn_3
=
False
,
use_flash_attn_3
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Attention implementation with context parallelism.
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
dimension, and by reducing the memory and computational pressure on each GPU, it enables long-context
LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes
the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s
and all `qkv_format`s, and it requires sequence lengths to be, or are padded to be, divisible by
(cp_size * 2). It also requires tokens to be re-ordered before entering this function.
For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated as below, for an example
use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position
in their corresponding sequence.
GPU0 | GPU1 GPU0 | GPU1
seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8
---------------------------|----------------- ---------------------------|------------------
0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1,
0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1,
5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1,
---------------------------|----------------- ---------------------------|------------------
6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0,
11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1,
For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different
lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of
every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of
batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and
cp_size = 2.
GPU0 | GPU1 GPU0 | GPU1
seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1
seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2
---------------------------|----------------- ---------------------------|------------------
0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0,
0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0,
0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2,
---------------------------|----------------- ---------------------------|------------------
0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0,
1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2,
When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks,
cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for
all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank
<https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_
in Megatron-LM.
"""
"""
if
cp_comm_type
==
"a2a+p2p"
:
if
cp_comm_type
==
"a2a+p2p"
:
...
@@ -3527,6 +3876,12 @@ def attn_forward_func_with_cp(
...
@@ -3527,6 +3876,12 @@ def attn_forward_func_with_cp(
"all_gather"
,
"all_gather"
,
],
"The context parallel running configs cannot support sliding window attetnion!"
],
"The context parallel running configs cannot support sliding window attetnion!"
enable_mla
=
k
.
shape
[
-
1
]
!=
v
.
shape
[
-
1
]
assert
not
enable_mla
or
cp_comm_type
in
[
"p2p"
,
"a2a+p2p"
,
],
"The context parallel running configs cannot support MLA!"
args
=
[
args
=
[
is_training
,
is_training
,
q
,
q
,
...
...
transformer_engine/pytorch/attention/dot_product_attention/utils.py
View file @
2b05e121
...
@@ -624,11 +624,6 @@ def get_attention_backend(
...
@@ -624,11 +624,6 @@ def get_attention_backend(
" bias for THD format"
" bias for THD format"
)
)
use_fused_attention
=
False
use_fused_attention
=
False
elif
head_dim_qk
!=
head_dim_v
:
logger
.
debug
(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention
=
False
# Filter: Attention mask
# Filter: Attention mask
# attn_mask_type | attention_mask | supported backends
# attn_mask_type | attention_mask | supported backends
...
@@ -782,6 +777,7 @@ def get_attention_backend(
...
@@ -782,6 +777,7 @@ def get_attention_backend(
q_type
=
get_fp8_te_dtype
(
fp8_meta
[
"recipe"
],
fprop_tensor
=
True
)
q_type
=
get_fp8_te_dtype
(
fp8_meta
[
"recipe"
],
fprop_tensor
=
True
)
kv_type
=
q_type
kv_type
=
q_type
fused_attention_backend
=
tex
.
get_fused_attn_backend
(
fused_attention_backend
=
tex
.
get_fused_attn_backend
(
is_training
,
q_type
,
q_type
,
kv_type
,
kv_type
,
QKVLayout
[
qkv_layout
],
QKVLayout
[
qkv_layout
],
...
@@ -962,15 +958,23 @@ def get_attention_backend(
...
@@ -962,15 +958,23 @@ def get_attention_backend(
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
get_padding_mask
(
def
get_padding_mask
(
batch_size
:
int
,
batch_size
:
int
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
=
None
,
cu_seqlens_kv
:
torch
.
Tensor
,
cu_seqlens_kv
:
torch
.
Tensor
=
None
,
max_seqlen_q
:
int
,
max_seqlen_q
:
int
=
None
,
max_seqlen_kv
:
int
,
max_seqlen_kv
:
int
=
None
,
attention_type
:
str
=
"self"
,
):
):
"""Convert cu_seqlens to attention_mask"""
"""Convert cu_seqlens to attention_mask"""
assert
(
cu_seqlens_q
is
not
None
and
max_seqlen_q
is
not
None
),
"cu_seqlens_q and max_seqlen_q are required for self-attention and cross-attention"
seqlens_q
=
cu_seqlens_q
[
1
:]
-
cu_seqlens_q
[:
-
1
]
seqlens_q
=
cu_seqlens_q
[
1
:]
-
cu_seqlens_q
[:
-
1
]
seqlens_kv
=
cu_seqlens_kv
[
1
:]
-
cu_seqlens_kv
[:
-
1
]
attention_mask_q
=
torch
.
Tensor
([]).
to
(
dtype
=
torch
.
bool
)
attention_mask_q
=
torch
.
Tensor
([]).
to
(
dtype
=
torch
.
bool
)
if
attention_type
==
"cross"
:
assert
(
cu_seqlens_kv
is
not
None
and
max_seqlen_kv
is
not
None
),
"cu_seqlens_kv and max_seqlen_kv are required for cross-attention"
seqlens_kv
=
cu_seqlens_kv
[
1
:]
-
cu_seqlens_kv
[:
-
1
]
attention_mask_kv
=
torch
.
Tensor
([]).
to
(
dtype
=
torch
.
bool
)
attention_mask_kv
=
torch
.
Tensor
([]).
to
(
dtype
=
torch
.
bool
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
attention_mask_q
=
torch
.
cat
(
attention_mask_q
=
torch
.
cat
(
...
@@ -984,6 +988,7 @@ def get_padding_mask(
...
@@ -984,6 +988,7 @@ def get_padding_mask(
],
],
dim
=
0
,
dim
=
0
,
)
)
if
attention_type
==
"cross"
:
attention_mask_kv
=
torch
.
cat
(
attention_mask_kv
=
torch
.
cat
(
[
[
attention_mask_kv
,
attention_mask_kv
,
...
@@ -995,8 +1000,12 @@ def get_padding_mask(
...
@@ -995,8 +1000,12 @@ def get_padding_mask(
],
],
dim
=
0
,
dim
=
0
,
)
)
attention_mask_q
=
attention_mask_q
.
to
(
device
=
"cuda"
)
if
attention_type
==
"self"
:
attention_mask
=
attention_mask_q
else
:
attention_mask
=
(
attention_mask
=
(
attention_mask_q
.
to
(
device
=
"cuda"
)
,
attention_mask_q
,
attention_mask_kv
.
to
(
device
=
"cuda"
),
attention_mask_kv
.
to
(
device
=
"cuda"
),
)
)
return
attention_mask
return
attention_mask
...
...
transformer_engine/pytorch/attention/multi_head_attention.py
View file @
2b05e121
...
@@ -12,6 +12,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
...
@@ -12,6 +12,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from
transformer_engine.pytorch.float8_tensor
import
Float8Tensor
from
transformer_engine.pytorch.float8_tensor
import
Float8Tensor
from
transformer_engine.pytorch.module.base
import
TransformerEngineBaseModule
from
transformer_engine.pytorch.module.base
import
TransformerEngineBaseModule
from
transformer_engine.pytorch.module
import
LayerNormLinear
,
Linear
from
transformer_engine.pytorch.module
import
LayerNormLinear
,
Linear
from
transformer_engine.pytorch.ops.basic.l2normalization
import
L2Normalization
from
transformer_engine.pytorch.utils
import
(
from
transformer_engine.pytorch.utils
import
(
SplitAlongDim
,
SplitAlongDim
,
divide
,
divide
,
...
@@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatentations/splits and also enables the argument
fusion without concatentations/splits and also enables the argument
`fuse_wgrad_accumulation`.
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = 'False'
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
seq_length: Optional[int], default = `None`
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
forward propagation and activation recompute phase.
micro_batch_size: Optional[int], default = `None`
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -214,6 +231,10 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -214,6 +231,10 @@ class MultiheadAttention(torch.nn.Module):
device
:
Union
[
torch
.
device
,
str
]
=
"cuda"
,
device
:
Union
[
torch
.
device
,
str
]
=
"cuda"
,
qkv_format
:
str
=
"sbhd"
,
qkv_format
:
str
=
"sbhd"
,
name
:
str
=
None
,
name
:
str
=
None
,
use_qk_norm
:
bool
=
False
,
qk_norm_eps
:
float
=
1e-6
,
seq_length
:
Optional
[
int
]
=
None
,
micro_batch_size
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -267,6 +288,7 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -267,6 +288,7 @@ class MultiheadAttention(torch.nn.Module):
self
.
hidden_size_kv
=
self
.
hidden_size_per_attention_head
*
self
.
num_gqa_groups
self
.
hidden_size_kv
=
self
.
hidden_size_per_attention_head
*
self
.
num_gqa_groups
self
.
name
=
name
self
.
name
=
name
self
.
use_qk_norm
=
use_qk_norm
common_gemm_kwargs
=
{
common_gemm_kwargs
=
{
"fuse_wgrad_accumulation"
:
fuse_wgrad_accumulation
,
"fuse_wgrad_accumulation"
:
fuse_wgrad_accumulation
,
...
@@ -278,6 +300,14 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -278,6 +300,14 @@ class MultiheadAttention(torch.nn.Module):
"device"
:
device
,
"device"
:
device
,
}
}
# Initialize L2 normalization modules for query and key if enabled
if
self
.
use_qk_norm
:
self
.
qk_norm
=
L2Normalization
(
eps
=
qk_norm_eps
,
seq_length
=
seq_length
,
micro_batch_size
=
micro_batch_size
,
)
qkv_parallel_mode
=
"column"
if
set_parallel_mode
else
None
qkv_parallel_mode
=
"column"
if
set_parallel_mode
else
None
if
self
.
attention_type
==
"self"
:
if
self
.
attention_type
==
"self"
:
...
@@ -482,6 +512,8 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -482,6 +512,8 @@ class MultiheadAttention(torch.nn.Module):
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_q
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_q
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_kv
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_kv
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_q_padded
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens_kv_padded
:
Optional
[
torch
.
Tensor
]
=
None
,
max_seqlen_q
:
Optional
[
int
]
=
None
,
max_seqlen_q
:
Optional
[
int
]
=
None
,
max_seqlen_kv
:
Optional
[
int
]
=
None
,
max_seqlen_kv
:
Optional
[
int
]
=
None
,
fast_zero_fill
:
bool
=
True
,
fast_zero_fill
:
bool
=
True
,
...
@@ -556,6 +588,12 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -556,6 +588,12 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q: Optional[int], default = `None`
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
Calculated from `cu_seqlens_q` if not provided.
...
@@ -714,6 +752,18 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -714,6 +752,18 @@ class MultiheadAttention(torch.nn.Module):
for
x
in
(
key_layer
,
value_layer
)
for
x
in
(
key_layer
,
value_layer
)
)
)
if
self
.
qkv_format
==
"thd"
:
key_layer
,
value_layer
=
(
x
.
reshape
(
x
.
size
(
0
),
-
1
,
self
.
hidden_size_per_attention_head
)
for
x
in
(
key_layer
,
value_layer
)
)
else
:
# key, value: -> [sq, b, ng, hn]
key_layer
,
value_layer
=
(
x
.
reshape
(
x
.
size
(
0
),
x
.
size
(
1
),
-
1
,
self
.
hidden_size_per_attention_head
)
for
x
in
(
key_layer
,
value_layer
)
)
# Attention head [sq, b, h] --> [sq, b, hp]
# Attention head [sq, b, h] --> [sq, b, hp]
if
self
.
input_layernorm
:
if
self
.
input_layernorm
:
layernorm_query_outputs
=
self
.
layernorm_query
(
layernorm_query_outputs
=
self
.
layernorm_query
(
...
@@ -792,6 +842,14 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -792,6 +842,14 @@ class MultiheadAttention(torch.nn.Module):
interleaved
=
self
.
rotary_pos_interleaved
,
interleaved
=
self
.
rotary_pos_interleaved
,
)
)
# ===========================
# Apply L2 normalization to query and key tensors
# ===========================
if
self
.
use_qk_norm
:
query_layer
=
self
.
qk_norm
(
query_layer
)
key_layer
=
self
.
qk_norm
(
key_layer
)
# ===========================
# ===========================
# Core attention computation
# Core attention computation
# ===========================
# ===========================
...
@@ -803,6 +861,8 @@ class MultiheadAttention(torch.nn.Module):
...
@@ -803,6 +861,8 @@ class MultiheadAttention(torch.nn.Module):
qkv_format
=
self
.
qkv_format
,
qkv_format
=
self
.
qkv_format
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_kv
=
cu_seqlens_kv
,
cu_seqlens_kv
=
cu_seqlens_kv
,
cu_seqlens_q_padded
=
cu_seqlens_q_padded
,
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_kv
=
max_seqlen_kv
,
max_seqlen_kv
=
max_seqlen_kv
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
...
...
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
2b05e121
...
@@ -140,6 +140,14 @@ def general_gemm(
...
@@ -140,6 +140,14 @@ def general_gemm(
# There is not use_split_accumulator == False
# There is not use_split_accumulator == False
# implementation for Float8BlockwiseQTensorBase GEMM
# implementation for Float8BlockwiseQTensorBase GEMM
use_split_accumulator
=
True
use_split_accumulator
=
True
# Check that data format is supported
if
(
A
.
_data_format
!=
tex
.
Float8BlockScaleTensorFormat
.
GEMM_READY
or
B
.
_data_format
!=
tex
.
Float8BlockScaleTensorFormat
.
GEMM_READY
):
raise
RuntimeError
(
"GEMM with Float8BlockwiseQTensor requires GEMM_READY format"
)
args
=
(
args
=
(
A
,
A
,
transa
,
# transa
transa
,
# transa
...
...
transformer_engine/pytorch/cpu_offload.py
View file @
2b05e121
...
@@ -253,13 +253,21 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
...
@@ -253,13 +253,21 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
return
state
return
state
@
staticmethod
@
staticmethod
def
reload
(
state
,
non_blocking
=
None
):
def
reload
(
state
,
non_blocking
=
None
,
copy_buffer
=
None
):
"""Reload."""
"""Reload."""
dev
,
cpu_backup
=
state
dev
,
cpu_backup
=
state
if
non_blocking
is
None
:
if
non_blocking
is
None
:
non_blocking
=
cpu_backup
.
is_pinned
()
non_blocking
=
cpu_backup
.
is_pinned
()
if
copy_buffer
is
None
:
return
cpu_backup
.
to
(
dev
,
non_blocking
=
non_blocking
)
return
cpu_backup
.
to
(
dev
,
non_blocking
=
non_blocking
)
assert
cpu_backup
.
size
()
==
copy_buffer
.
size
(),
"Can't copy two buffers of different sizes!"
copy_buffer
.
copy_
(
cpu_backup
,
non_blocking
=
non_blocking
)
return
copy_buffer
def
tensor_push
(
self
,
tensor
:
torch
.
Tensor
,
**
kwargs
):
def
tensor_push
(
self
,
tensor
:
torch
.
Tensor
,
**
kwargs
):
"""Tensor push."""
"""Tensor push."""
# obtain a unique tensor tag
# obtain a unique tensor tag
...
@@ -300,6 +308,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -300,6 +308,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
num_offload_group
,
# must be <= actual number of groups (number of commits)
num_offload_group
,
# must be <= actual number of groups (number of commits)
num_model_group
,
num_model_group
,
tensor_need_offloading_checker
=
(
lambda
t
:
True
),
tensor_need_offloading_checker
=
(
lambda
t
:
True
),
double_buffering
=
False
,
debug
=
False
,
debug
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
...
@@ -314,11 +323,17 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -314,11 +323,17 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Data structure to hold the FP8/MXFP8 tensor objects
# Data structure to hold the FP8/MXFP8 tensor objects
self
.
fp8_tensor_object_map
=
{}
self
.
fp8_tensor_object_map
=
{}
self
.
float8_transpose_cache_valid
=
{}
self
.
float8_transpose_cache_valid
=
{}
self
.
dereferencing_list
=
[]
# Tracking the number of layers offloaded
# Tracking the number of layers offloaded
self
.
offloaded_group_count
=
0
self
.
offloaded_group_count
=
0
# Core data structure that decides the window for offloading
# Core data structure that decides the window for offloading
self
.
layer_window_map
=
{}
self
.
layer_window_map
=
{}
# Data structures fo double buffered reloading
self
.
double_buffering
=
double_buffering
self
.
reload_double_buffer
=
[[],
[]]
self
.
double_buffer_created
=
False
# Logic to make offloading load balance across computation
# Logic to make offloading load balance across computation
# for optimal CPU/GPU interconnect usage
# for optimal CPU/GPU interconnect usage
constant
=
0
constant
=
0
...
@@ -360,6 +375,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -360,6 +375,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self
.
tensor_tag_to_state
[
tensor_tag
]
=
[]
self
.
tensor_tag_to_state
[
tensor_tag
]
=
[]
self
.
tensor_tag_to_buf
[
tensor_tag
]
=
[]
self
.
tensor_tag_to_buf
[
tensor_tag
]
=
[]
# Added support for de-duplicating FP8 param tensors
for
_
,
value
in
self
.
fp8_tensor_object_map
.
items
():
if
tensor
is
value
:
self
.
dereferencing_list
.
append
(
tensor_tag
)
break
self
.
fp8_tensor_object_map
[
tensor_tag
]
=
tensor
self
.
fp8_tensor_object_map
[
tensor_tag
]
=
tensor
if
isinstance
(
tensor
,
Float8Tensor
):
if
isinstance
(
tensor
,
Float8Tensor
):
self
.
float8_transpose_cache_valid
[
tensor_tag
]
=
getattr
(
self
.
float8_transpose_cache_valid
[
tensor_tag
]
=
getattr
(
...
@@ -398,11 +419,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -398,11 +419,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Handling the quantized tensor case specially here
# Handling the quantized tensor case specially here
if
isinstance
(
tensor
,
list
):
if
isinstance
(
tensor
,
list
):
# If it's a duplicated tensor, we don't need to locally
# write back a tensor as it would already be written
if
tensor_tag
in
self
.
dereferencing_list
:
self
.
dereferencing_list
.
remove
(
tensor_tag
)
else
:
self
.
fp8_tensor_object_map
[
tensor_tag
].
restore_from_saved
(
tensor
)
self
.
fp8_tensor_object_map
[
tensor_tag
].
restore_from_saved
(
tensor
)
tensor
=
self
.
fp8_tensor_object_map
.
pop
(
tensor_tag
)
tensor
=
self
.
fp8_tensor_object_map
.
pop
(
tensor_tag
)
self
.
tensor_tag_to_buf
.
pop
(
tensor_tag
,
None
)
if
self
.
double_buffering
:
tensor
.
do_not_clear
=
True
self
.
tensor_tag_to_buf
.
pop
(
tensor_tag
,
None
)
# the tensor should have been copied back in on_group_commit_backward()
# the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group.
# which invokes bulk_reload_group.
assert
not
isinstance
(
tensor
,
tuple
)
assert
not
isinstance
(
tensor
,
tuple
)
...
@@ -454,6 +482,20 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -454,6 +482,20 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# the first compute completion
# the first compute completion
if
current_group
==
0
:
if
current_group
==
0
:
self
.
d2h_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
self
.
d2h_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
if
not
self
.
double_buffer_created
:
# Creating the first copy of double buffer for tensors that are offloaded
for
tensor_tag
,
buf
in
self
.
tensor_tag_to_buf
.
items
():
if
isinstance
(
buf
,
list
):
for
b
in
buf
:
self
.
reload_double_buffer
[
0
].
append
(
torch
.
empty_like
(
b
)
if
self
.
double_buffering
else
None
)
else
:
self
.
reload_double_buffer
[
0
].
append
(
torch
.
empty_like
(
buf
)
if
self
.
double_buffering
else
None
)
self
.
bulk_offload_group
(
current_group
)
self
.
bulk_offload_group
(
current_group
)
# Window map data structure helps us synchronize based on number
# Window map data structure helps us synchronize based on number
...
@@ -483,6 +525,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -483,6 +525,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Increment the offload group count to keep track
# Increment the offload group count to keep track
self
.
offloaded_group_count
+=
1
self
.
offloaded_group_count
+=
1
if
not
self
.
double_buffer_created
:
# Creating second copy of double buffer for tensors that are offloaded
if
current_group
==
(
self
.
num_layers
-
1
):
for
buf
in
self
.
reload_double_buffer
[
0
]:
self
.
reload_double_buffer
[
1
].
append
(
torch
.
empty_like
(
buf
)
if
self
.
double_buffering
else
None
)
self
.
double_buffer_created
=
True
def
on_group_commit_forward
(
self
):
def
on_group_commit_forward
(
self
):
"""This function will cause host device synchronization"""
"""This function will cause host device synchronization"""
# handle synchronization events
# handle synchronization events
...
@@ -494,28 +545,49 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
...
@@ -494,28 +545,49 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Bulk reload group."""
"""Bulk reload group."""
assert
group_to_reload
<
self
.
num_offload_group
assert
group_to_reload
<
self
.
num_offload_group
buffer_idx
=
0
double_buffer_idx
=
group_to_reload
%
2
with
torch
.
cuda
.
stream
(
self
.
h2d_stream
):
with
torch
.
cuda
.
stream
(
self
.
h2d_stream
):
# move back tensors
# move back tensors
for
tensor_label
,
state
in
self
.
tensor_tag_to_state
.
items
():
for
tensor_label
,
state
in
self
.
tensor_tag_to_state
.
items
():
group_id
,
_
=
tensor_label
group_id
,
_
=
tensor_label
if
group_id
==
group_to_reload
:
if
group_id
==
group_to_reload
:
if
isinstance
(
state
,
tuple
):
if
isinstance
(
state
,
tuple
):
recovered_tensor
=
SynchronizedGroupOffloadHandler
.
reload
(
state
)
recovered_tensor
=
SynchronizedGroupOffloadHandler
.
reload
(
state
,
True
,
self
.
reload_double_buffer
[
double_buffer_idx
][
buffer_idx
]
)
buffer_idx
=
buffer_idx
+
1
self
.
tensor_tag_to_state
[
tensor_label
]
=
recovered_tensor
self
.
tensor_tag_to_state
[
tensor_label
]
=
recovered_tensor
elif
isinstance
(
state
,
list
):
elif
isinstance
(
state
,
list
):
tensor_list
=
[]
tensor_list
=
[]
for
state_tuple
in
state
:
for
state_tuple
in
state
:
if
isinstance
(
state_tuple
,
tuple
):
if
isinstance
(
state_tuple
,
tuple
):
tensor_list
.
append
(
tensor_list
.
append
(
SynchronizedGroupOffloadHandler
.
reload
(
state_tuple
)
SynchronizedGroupOffloadHandler
.
reload
(
state_tuple
,
True
,
self
.
reload_double_buffer
[
double_buffer_idx
][
buffer_idx
],
)
)
)
buffer_idx
=
buffer_idx
+
1
else
:
else
:
tensor_list
.
append
(
state_tuple
)
tensor_list
.
append
(
state_tuple
)
_
=
self
.
fp8_tensor_object_map
[
tensor_label
].
restore_from_saved
(
tensor_list
)
# No need to write back the duplicated tensor againn
# to the same location, this check ensures that
if
tensor_label
in
self
.
dereferencing_list
:
self
.
dereferencing_list
.
remove
(
tensor_label
)
else
:
_
=
self
.
fp8_tensor_object_map
[
tensor_label
].
restore_from_saved
(
tensor_list
)
if
isinstance
(
self
.
fp8_tensor_object_map
[
tensor_label
],
Float8Tensor
):
if
isinstance
(
self
.
fp8_tensor_object_map
[
tensor_label
],
Float8Tensor
):
self
.
fp8_tensor_object_map
[
tensor_label
].
_transpose_invalid
=
(
self
.
fp8_tensor_object_map
[
tensor_label
].
_transpose_invalid
=
(
self
.
float8_transpose_cache_valid
.
pop
(
tensor_label
)
self
.
float8_transpose_cache_valid
.
pop
(
tensor_label
)
)
)
self
.
tensor_tag_to_state
[
tensor_label
]
=
self
.
fp8_tensor_object_map
.
pop
(
self
.
tensor_tag_to_state
[
tensor_label
]
=
self
.
fp8_tensor_object_map
.
pop
(
tensor_label
tensor_label
)
)
...
@@ -552,6 +624,7 @@ def get_cpu_offload_context(
...
@@ -552,6 +624,7 @@ def get_cpu_offload_context(
model_layers
:
int
=
1
,
model_layers
:
int
=
1
,
offload_activations
:
bool
=
True
,
offload_activations
:
bool
=
True
,
offload_weights
:
bool
=
False
,
offload_weights
:
bool
=
False
,
double_buffering
:
bool
=
False
,
):
):
"""
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
This function returns the CPU Offload context and the synchronizer function that needs to be
...
@@ -580,6 +653,8 @@ def get_cpu_offload_context(
...
@@ -580,6 +653,8 @@ def get_cpu_offload_context(
When set to `True`, offloads the activations for the TE layer.
When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `True`
offload_weights: bool, default = `True`
When set to `True`, offloads the weights for the TE layer.
When set to `True`, offloads the weights for the TE layer.
double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading.
"""
"""
...
@@ -611,6 +686,7 @@ def get_cpu_offload_context(
...
@@ -611,6 +686,7 @@ def get_cpu_offload_context(
num_offload_group
=
num_layers
,
num_offload_group
=
num_layers
,
num_model_group
=
model_layers
,
num_model_group
=
model_layers
,
tensor_need_offloading_checker
=
tensor_need_offloading_checker
,
tensor_need_offloading_checker
=
tensor_need_offloading_checker
,
double_buffering
=
double_buffering
,
)
)
def
group_prefetch_offload_commit_async
(
tensor
):
def
group_prefetch_offload_commit_async
(
tensor
):
...
...
transformer_engine/pytorch/csrc/common.cpp
View file @
2b05e121
...
@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) {
...
@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) {
return
shape
;
return
shape
;
}
}
NVTEShape
convertTorchShape
(
const
c10
::
IntArrayRef
torch_shape
)
{
NVTEShape
ret
;
ret
.
ndim
=
torch_shape
.
size
();
constexpr
int
max_dimensions
=
sizeof
(
ret
.
data
)
/
sizeof
(
size_t
);
NVTE_CHECK
(
ret
.
ndim
<
max_dimensions
,
"Torch tensor has too many dimensions. Max supported: "
,
max_dimensions
,
" and got "
,
ret
.
ndim
,
"."
);
for
(
size_t
i
=
0
;
i
<
ret
.
ndim
;
++
i
)
{
const
auto
&
v
=
torch_shape
[
i
];
ret
.
data
[
i
]
=
static_cast
<
size_t
>
(
v
);
}
return
ret
;
}
std
::
unique_ptr
<
Quantizer
>
convert_quantizer
(
py
::
handle
quantizer
)
{
std
::
unique_ptr
<
Quantizer
>
convert_quantizer
(
py
::
handle
quantizer
)
{
init_extension
();
init_extension
();
if
(
quantizer
.
is_none
())
{
if
(
quantizer
.
is_none
())
{
...
...
transformer_engine/pytorch/csrc/common.h
View file @
2b05e121
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/padding.h>
...
@@ -177,6 +178,8 @@ class Float8BlockQuantizer : public Quantizer {
...
@@ -177,6 +178,8 @@ class Float8BlockQuantizer : public Quantizer {
bool
force_pow_2_scales
=
false
;
bool
force_pow_2_scales
=
false
;
// Amax within quantization tile has a floor of epsilon.
// Amax within quantization tile has a floor of epsilon.
float
amax_epsilon
=
0.0
;
float
amax_epsilon
=
0.0
;
// Whether quantized tensor will be used in an all-gather
bool
all_gather_usage
=
false
;
private:
private:
int
block_scaling_dim
=
2
;
int
block_scaling_dim
=
2
;
...
@@ -222,21 +225,23 @@ std::vector<size_t> getTensorShape(at::Tensor t);
...
@@ -222,21 +225,23 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine
::
DType
getTransformerEngineFP8Type
(
bool
e4m3_if_hybrid
,
transformer_engine
::
DType
getTransformerEngineFP8Type
(
bool
e4m3_if_hybrid
,
const
std
::
string
&
fp8_recipe
);
const
std
::
string
&
fp8_recipe
);
inline
size_t
typeTo
Size
(
transformer_engine
::
DType
t
)
{
inline
size_t
typeTo
NumBits
(
transformer_engine
::
DType
t
)
{
switch
(
t
)
{
switch
(
t
)
{
case
transformer_engine
::
DType
::
kInt64
:
case
transformer_engine
::
DType
::
kInt64
:
return
8
;
return
64
;
case
transformer_engine
::
DType
::
kInt32
:
case
transformer_engine
::
DType
::
kInt32
:
case
transformer_engine
::
DType
::
kFloat32
:
case
transformer_engine
::
DType
::
kFloat32
:
return
4
;
return
32
;
case
transformer_engine
::
DType
::
kInt16
:
case
transformer_engine
::
DType
::
kInt16
:
case
transformer_engine
::
DType
::
kFloat16
:
case
transformer_engine
::
DType
::
kFloat16
:
case
transformer_engine
::
DType
::
kBFloat16
:
case
transformer_engine
::
DType
::
kBFloat16
:
return
2
;
return
16
;
case
transformer_engine
::
DType
::
kByte
:
case
transformer_engine
::
DType
::
kByte
:
case
transformer_engine
::
DType
::
kFloat8E4M3
:
case
transformer_engine
::
DType
::
kFloat8E4M3
:
case
transformer_engine
::
DType
::
kFloat8E5M2
:
case
transformer_engine
::
DType
::
kFloat8E5M2
:
return
1
;
return
8
;
case
transformer_engine
::
DType
::
kFloat4E2M1
:
return
4
;
default:
default:
NVTE_ERROR
(
"Invalid type"
);
NVTE_ERROR
(
"Invalid type"
);
}
}
...
@@ -355,6 +360,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
...
@@ -355,6 +360,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
int
roundup
(
const
int
value
,
const
int
multiple
);
int
roundup
(
const
int
value
,
const
int
multiple
);
NVTEShape
convertTorchShape
(
const
c10
::
IntArrayRef
torch_shape
);
}
// namespace transformer_engine::pytorch
}
// namespace transformer_engine::pytorch
namespace
std
{
namespace
std
{
...
...
transformer_engine/pytorch/csrc/extensions.h
View file @
2b05e121
...
@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
...
@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
* Attention
* Attention
**************************************************************************************************/
**************************************************************************************************/
NVTE_Fused_Attn_Backend
get_fused_attn_backend
(
const
DType
q_dtype
,
const
DType
kv_dtype
,
NVTE_Fused_Attn_Backend
get_fused_attn_backend
(
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
is_training
,
const
DType
q_dtype
,
const
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
attn_mask_type
,
float
p_dropout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
float
p_dropout
,
size_t
num_attn_heads
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
);
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
);
std
::
vector
<
py
::
object
>
fused_attn_fwd
(
std
::
vector
<
py
::
object
>
fused_attn_fwd
(
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
...
@@ -450,6 +448,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
...
@@ -450,6 +448,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
at
::
Tensor
get_buffer
(
bool
local_chunk
=
false
,
at
::
Tensor
get_buffer
(
bool
local_chunk
=
false
,
std
::
optional
<
std
::
vector
<
int64_t
>>
shape
=
std
::
nullopt
);
std
::
optional
<
std
::
vector
<
int64_t
>>
shape
=
std
::
nullopt
);
at
::
Stream
get_communication_stream
();
};
// CommOverlap
};
// CommOverlap
class
CommOverlapP2P
:
torch
::
CustomClassHolder
,
public
transformer_engine
::
CommOverlapP2PBase
{
class
CommOverlapP2P
:
torch
::
CustomClassHolder
,
public
transformer_engine
::
CommOverlapP2PBase
{
...
@@ -469,6 +469,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
...
@@ -469,6 +469,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
at
::
Tensor
get_buffer
(
bool
local_chunk
=
false
,
at
::
Tensor
get_buffer
(
bool
local_chunk
=
false
,
std
::
optional
<
std
::
vector
<
int64_t
>>
shape
=
std
::
nullopt
);
std
::
optional
<
std
::
vector
<
int64_t
>>
shape
=
std
::
nullopt
);
at
::
Stream
get_communication_stream
();
};
// CommOverlapP2P
};
// CommOverlapP2P
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
transformer_engine/pytorch/csrc/extensions/activation.cpp
View file @
2b05e121
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "common.h"
#include "extensions.h"
#include "pybind.h"
#include "pybind.h"
namespace
transformer_engine
::
pytorch
{
namespace
transformer_engine
::
pytorch
{
...
...
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp
View file @
2b05e121
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "common.h"
#include "extensions.h"
namespace
transformer_engine
::
pytorch
{
namespace
transformer_engine
::
pytorch
{
...
...
Prev
1
…
6
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment