Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
646
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1160 additions
and
147 deletions
+1160
-147
transformer_engine/debug/features/log_fp8_tensor_stats.py
transformer_engine/debug/features/log_fp8_tensor_stats.py
+1
-1
transformer_engine/debug/features/log_tensor_stats.py
transformer_engine/debug/features/log_tensor_stats.py
+1
-1
transformer_engine/debug/features/per_tensor_scaling.py
transformer_engine/debug/features/per_tensor_scaling.py
+1
-1
transformer_engine/debug/features/utils/__init__.py
transformer_engine/debug/features/utils/__init__.py
+1
-1
transformer_engine/debug/features/utils/stats_buffer.py
transformer_engine/debug/features/utils/stats_buffer.py
+1
-1
transformer_engine/debug/features/utils/stats_computation.py
transformer_engine/debug/features/utils/stats_computation.py
+1
-1
transformer_engine/debug/pytorch/__init__.py
transformer_engine/debug/pytorch/__init__.py
+1
-1
transformer_engine/debug/pytorch/debug_quantization.py
transformer_engine/debug/pytorch/debug_quantization.py
+29
-7
transformer_engine/debug/pytorch/debug_state.py
transformer_engine/debug/pytorch/debug_state.py
+1
-1
transformer_engine/debug/pytorch/utils.py
transformer_engine/debug/pytorch/utils.py
+1
-1
transformer_engine/jax/__init__.py
transformer_engine/jax/__init__.py
+1
-1
transformer_engine/jax/activation.py
transformer_engine/jax/activation.py
+1
-1
transformer_engine/jax/attention.py
transformer_engine/jax/attention.py
+203
-32
transformer_engine/jax/checkpoint_policies.py
transformer_engine/jax/checkpoint_policies.py
+1
-1
transformer_engine/jax/cpp_extensions/__init__.py
transformer_engine/jax/cpp_extensions/__init__.py
+1
-1
transformer_engine/jax/cpp_extensions/activation.py
transformer_engine/jax/cpp_extensions/activation.py
+2
-2
transformer_engine/jax/cpp_extensions/amax.py
transformer_engine/jax/cpp_extensions/amax.py
+3
-3
transformer_engine/jax/cpp_extensions/attention.py
transformer_engine/jax/cpp_extensions/attention.py
+901
-87
transformer_engine/jax/cpp_extensions/base.py
transformer_engine/jax/cpp_extensions/base.py
+7
-1
transformer_engine/jax/cpp_extensions/gemm.py
transformer_engine/jax/cpp_extensions/gemm.py
+2
-2
No files found.
transformer_engine/debug/features/log_fp8_tensor_stats.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/features/log_tensor_stats.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/features/per_tensor_scaling.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/features/utils/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/features/utils/stats_buffer.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/features/utils/stats_computation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/pytorch/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
transformer_engine/debug/pytorch/debug_quantization.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
@@ -9,7 +9,7 @@ These wrappers add logic related to debugging, using the nvdlfw_inspect package.
...
@@ -9,7 +9,7 @@ These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
,
Union
from
typing
import
Optional
,
Tuple
,
Iterable
,
Union
,
List
import
torch
import
torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
...
@@ -62,12 +62,17 @@ class DebugQuantizer(Quantizer):
...
@@ -62,12 +62,17 @@ class DebugQuantizer(Quantizer):
self
.
tp_group
=
tp_group
# used in inspect_tensor calls
self
.
tp_group
=
tp_group
# used in inspect_tensor calls
self
.
iteration
=
TEDebugState
.
get_iteration
()
self
.
iteration
=
TEDebugState
.
get_iteration
()
# Configure parent quantizer
if
parent_quantizer
is
not
None
:
# .internal = True is slightly faster, but results
# .internal = True is slightly faster, but results
# in errors when caching the weights.
# in errors when caching the weights.
# Setting .internal = False is safer.
# Setting .internal = False is safer.
if
parent_quantizer
is
not
None
:
parent_quantizer
.
internal
=
False
parent_quantizer
.
internal
=
False
# .optimize_for_gemm = True is not supported because debug
# quantizers perform non-GEMM operations.
parent_quantizer
.
optimize_for_gemm
=
False
self
.
rowwise_gemm_name
,
self
.
columnwise_gemm_name
=
_tensor_to_gemm_names_map
[
tensor_name
]
self
.
rowwise_gemm_name
,
self
.
columnwise_gemm_name
=
_tensor_to_gemm_names_map
[
tensor_name
]
# next iteration when this quantizer will call any API
# next iteration when this quantizer will call any API
...
@@ -556,6 +561,23 @@ class DebugQuantizer(Quantizer):
...
@@ -556,6 +561,23 @@ class DebugQuantizer(Quantizer):
if
not
self
.
output_tensor
:
if
not
self
.
output_tensor
:
self
.
_update_parent_quantizer_usage
()
self
.
_update_parent_quantizer_usage
()
@
classmethod
def
multi_tensor_quantize
(
cls
,
tensor
:
torch
.
Tensor
,
quantizers
:
List
[
Quantizer
],
m_splits
:
List
[
int
],
activation_dtype
:
torch
.
dtype
,
)
->
List
[
DebugQuantizedTensor
]:
"""
Splits a tensor into a list of tensors and quantizes each tensor using a list of quantizers.
"""
tensors
=
torch
.
split
(
tensor
,
m_splits
)
output
=
[]
for
tensor
,
quantizer
in
zip
(
tensors
,
quantizers
):
output
.
append
(
quantizer
.
quantize
(
tensor
,
dtype
=
activation_dtype
))
return
output
class
DebugQuantizedTensor
(
QuantizedTensorStorage
):
class
DebugQuantizedTensor
(
QuantizedTensorStorage
):
"""
"""
...
@@ -623,9 +645,9 @@ class DebugQuantizedTensor(QuantizedTensorStorage):
...
@@ -623,9 +645,9 @@ class DebugQuantizedTensor(QuantizedTensorStorage):
"""Is used in the python gemm() to get tensor or transpose of the tensor."""
"""Is used in the python gemm() to get tensor or transpose of the tensor."""
return
self
.
rowwise_gemm_tensor
if
not
transpose
else
self
.
columnwise_gemm_tensor
return
self
.
rowwise_gemm_tensor
if
not
transpose
else
self
.
columnwise_gemm_tensor
def
size
(
self
):
def
size
(
self
,
*
args
):
"""Size of the tensor."""
"""Size of the tensor."""
return
self
.
rowwise_gemm_tensor
.
size
()
return
self
.
rowwise_gemm_tensor
.
size
(
*
args
)
def
update_usage
(
self
,
rowwise_usage
:
bool
=
None
,
columnwise_usage
:
bool
=
None
):
def
update_usage
(
self
,
rowwise_usage
:
bool
=
None
,
columnwise_usage
:
bool
=
None
):
"""Update usage of the tensor."""
"""Update usage of the tensor."""
...
...
transformer_engine/debug/pytorch/debug_state.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/debug/pytorch/utils.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
...
...
transformer_engine/jax/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""Transformer Engine bindings for JAX.
"""Transformer Engine bindings for JAX.
...
...
transformer_engine/jax/activation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
"""Activation functions for Transformer Engine in JAX.
...
...
transformer_engine/jax/attention.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX multi-head attention modules"""
"""JAX multi-head attention modules"""
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
NVTE_Softmax_Type
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
]
class
AttnSoftmaxType
(
Enum
):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [H].
"""
VANILLA_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_LEARNABLE_SOFTMAX
@
classmethod
def
from_str
(
cls
,
softmax_type
:
str
)
->
"AttnSoftmaxType"
:
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map
=
{
"vanilla"
:
cls
.
VANILLA_SOFTMAX
,
"off_by_one"
:
cls
.
OFF_BY_ONE_SOFTMAX
,
"learnable"
:
cls
.
LEARNABLE_SOFTMAX
,
}
result
=
softmax_type_map
.
get
(
softmax_type
)
if
result
is
None
:
raise
ValueError
(
f
"Unknown softmax_type:
{
softmax_type
}
. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return
result
class
QKVFormat
(
Enum
):
class
QKVFormat
(
Enum
):
"""
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
SBHD: q,k,v memory layout with [s, b, ..., h, d]
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
qkv_layout
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
dropout_probability
,
q_num_heads
,
q_num_heads
,
kv_num_heads
,
kv_num_heads
,
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
"""
"""
To check whether the fused attention kernel is supported
To check whether the fused attention kernel is supported
"""
"""
window_size_tuple
=
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
def
make_helper
(
attn_mask_type
):
def
make_helper
(
attn_mask_type
):
return
tex
.
FusedAttnHelper
(
return
tex
.
FusedAttnHelper
(
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
qkv_layout
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
dropout_probability
,
q_num_heads
,
q_num_heads
,
kv_num_heads
,
kv_num_heads
,
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen
,
kv_max_seqlen
,
head_dim_qk
,
head_dim_qk
,
head_dim_v
,
head_dim_v
,
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
,
window_siz
e_tupl
e
,
)
)
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
...
@@ -353,23 +386,57 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
...
@@ -353,23 +386,57 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return
batch
,
q_max_seqlen
,
kv_max_seqlen
return
batch
,
q_max_seqlen
,
kv_max_seqlen
def
reorder_causal_load_balancing
(
tensor
,
strategy
:
ReorderStrategy
,
cp_size
:
int
,
seq_dim
:
int
):
def
reorder_causal_load_balancing
(
tensor
,
strategy
:
ReorderStrategy
,
cp_size
:
int
,
seq_dim
:
int
,
stripe_size
:
int
|
None
=
None
):
"""Reorders a tensor for load balancing the compute of causal attention."""
"""Reorders a tensor for load balancing the compute of causal attention."""
if
strategy
==
ReorderStrategy
.
DualChunkSwap
:
if
strategy
==
ReorderStrategy
.
DualChunkSwap
:
if
stripe_size
is
not
None
:
raise
ValueError
(
f
"Incorrect value for CP dual chunk reordering
{
stripe_size
=
}
. stripe_size must be"
" None"
)
return
tex
.
attention
.
reorder_causal_dual_chunk_swap
(
tensor
,
cp_size
,
seq_dim
,
False
)
return
tex
.
attention
.
reorder_causal_dual_chunk_swap
(
tensor
,
cp_size
,
seq_dim
,
False
)
if
strategy
==
ReorderStrategy
.
Striped
:
if
strategy
==
ReorderStrategy
.
Striped
:
return
tex
.
attention
.
reorder_causal_striped
(
tensor
,
cp_size
,
seq_dim
,
False
)
# stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
# stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
if
stripe_size
is
not
None
and
stripe_size
<=
0
:
raise
ValueError
(
f
"Incorrect value for CP striped reordering
{
stripe_size
=
}
. stripe_size must be a"
" positive integer"
)
# Supporting old API defaults of stripe_size=1
effective_stripe_size
=
1
if
stripe_size
is
None
else
stripe_size
return
tex
.
attention
.
reorder_causal_striped
(
tensor
,
cp_size
,
seq_dim
,
False
,
effective_stripe_size
)
raise
ValueError
(
f
"Unsupported
{
strategy
=
}
"
)
raise
ValueError
(
f
"Unsupported
{
strategy
=
}
"
)
def
inverse_reorder_causal_load_balancing
(
def
inverse_reorder_causal_load_balancing
(
tensor
,
strategy
:
ReorderStrategy
,
cp_size
:
int
,
seq_dim
:
int
tensor
,
strategy
:
ReorderStrategy
,
cp_size
:
int
,
seq_dim
:
int
,
stripe_size
:
int
|
None
=
None
):
):
"""Inverse operation of `reorder_causal_load_balancing`."""
"""Inverse operation of `reorder_causal_load_balancing`."""
if
strategy
==
ReorderStrategy
.
DualChunkSwap
:
if
strategy
==
ReorderStrategy
.
DualChunkSwap
:
if
stripe_size
is
not
None
:
raise
ValueError
(
f
"Incorrect value for CP dual chunk reordering
{
stripe_size
=
}
. stripe_size must be"
" None"
)
return
tex
.
attention
.
reorder_causal_dual_chunk_swap
(
tensor
,
cp_size
,
seq_dim
,
True
)
return
tex
.
attention
.
reorder_causal_dual_chunk_swap
(
tensor
,
cp_size
,
seq_dim
,
True
)
if
strategy
==
ReorderStrategy
.
Striped
:
if
strategy
==
ReorderStrategy
.
Striped
:
return
tex
.
attention
.
reorder_causal_striped
(
tensor
,
cp_size
,
seq_dim
,
True
)
# stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
# stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
if
stripe_size
is
not
None
and
stripe_size
<=
0
:
raise
ValueError
(
f
"Incorrect value for CP reordering
{
stripe_size
=
}
. stripe_size must be a positive"
" integer"
)
# Supporting old API defaults of stripe_size=1
effective_stripe_size
=
1
if
stripe_size
is
None
else
stripe_size
return
tex
.
attention
.
reorder_causal_striped
(
tensor
,
cp_size
,
seq_dim
,
True
,
effective_stripe_size
)
raise
ValueError
(
f
"Unsupported
{
strategy
=
}
"
)
raise
ValueError
(
f
"Unsupported
{
strategy
=
}
"
)
...
@@ -497,6 +564,11 @@ def _segment_ids_pos_to_seqlens_offsets(
...
@@ -497,6 +564,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# examine only O(Q+KV) elements.
# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
...
@@ -558,21 +630,6 @@ def _segment_ids_pos_to_seqlens_offsets(
...
@@ -558,21 +630,6 @@ def _segment_ids_pos_to_seqlens_offsets(
)
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask
=
(
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
,
segment_ids_q
=
segment_ids_q
,
segment_ids_kv
=
segment_ids_kv
,
)
if
attn_mask_type
.
is_bottom_right
()
else
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
)
)
attn_mask
=
jnp
.
logical_and
(
attn_mask
,
swa_mask
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
attn_mask_with_id
,
max_segments_per_seq
attn_mask_with_id
,
max_segments_per_seq
...
@@ -601,7 +658,7 @@ class SequenceDescriptor:
...
@@ -601,7 +658,7 @@ class SequenceDescriptor:
- SequenceDescriptor.from_seqlens_and_offsets
- SequenceDescriptor.from_seqlens_and_offsets
For THD (packed) cases, where each batch may have not only 1 sequence.
For THD (packed) cases, where each batch may have not only 1 sequence.
- SequenceDescriptor.from_segment_ids_and_pos
- SequenceDescriptor.from_segment_ids_and_pos
Experimental feature for THD (packed) cases with
context parallelism.
Experimental feature for
BSHD (with and without reordering) and
THD (packed) cases with
out reordering
"""
"""
seqlens
:
Optional
[
Tuple
[
jnp
.
ndarray
,
jnp
.
ndarray
]]
seqlens
:
Optional
[
Tuple
[
jnp
.
ndarray
,
jnp
.
ndarray
]]
...
@@ -739,9 +796,14 @@ class SequenceDescriptor:
...
@@ -739,9 +796,14 @@ class SequenceDescriptor:
cls
,
cls
,
segment_ids
:
Union
[
jnp
.
ndarray
,
Tuple
[
jnp
.
ndarray
,
jnp
.
ndarray
]],
segment_ids
:
Union
[
jnp
.
ndarray
,
Tuple
[
jnp
.
ndarray
,
jnp
.
ndarray
]],
segment_pos
:
Optional
[
Union
[
jnp
.
ndarray
,
Tuple
[
jnp
.
ndarray
,
jnp
.
ndarray
]]]
=
None
,
segment_pos
:
Optional
[
Union
[
jnp
.
ndarray
,
Tuple
[
jnp
.
ndarray
,
jnp
.
ndarray
]]]
=
None
,
*
,
is_thd
:
bool
,
is_segment_ids_reordered
:
bool
,
)
->
SequenceDescriptor
:
)
->
SequenceDescriptor
:
"""
"""
Experimental factory method for inputs with segment IDs and optional positions. (THD)
Experimental factory method for inputs with segment IDs and optional positions.
segment_pos = None to be used only for: BSHD with or without load balancing and,
THD without load balancing
Args:
Args:
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
- q_segment_ids (jnp.ndarray):
- q_segment_ids (jnp.ndarray):
...
@@ -755,22 +817,84 @@ class SequenceDescriptor:
...
@@ -755,22 +817,84 @@ class SequenceDescriptor:
The position inside each segment for query, with shape [batch, max_seqlen].
The position inside each segment for query, with shape [batch, max_seqlen].
- kv_segment_pos (jnp.ndarray):
- kv_segment_pos (jnp.ndarray):
The position inside each segment for key, value, with shape [batch, max_seqlen].
The position inside each segment for key, value, with shape [batch, max_seqlen].
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing.
Only THD with load balancing is expected to have this flag set to True
Return:
Return:
A SequenceDescriptor with segment_ids/segment_pos initialized.
A SequenceDescriptor with segment_ids/segment_pos initialized.
"""
"""
q_seg_ids
,
kv_seg_ids
=
cls
.
_expand_to_pair
(
segment_ids
)
q_seg_ids
,
kv_seg_ids
=
cls
.
_expand_to_pair
(
segment_ids
)
if
segment_pos
is
not
None
:
# Using defaults : segment pos has to be generated.
segment_pos
=
cls
.
_expand_to_pair
(
segment_pos
)
if
segment_pos
is
None
:
else
:
# THD + load balanced segment_ids are not supported in this function
# BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself
if
is_segment_ids_reordered
:
assert
not
is_thd
,
(
f
"
{
segment_pos
=
}
default arg is not supported for load balanced reordered"
" (Striped) THD inputs. Please pass the load balanced reordered segment_pos"
" and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}"
" using convenience function reorder_causal_load_balancing()"
)
assert
is_thd
,
(
f
"
{
segment_pos
=
}
default arg is not supported for load balanced reordered (Dual"
" Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load"
" balanced reordered. The reordering for these is performed within the"
" primitive"
)
# Generate the default pos for THD and BSHD non-reordered segment_ids
def
generate_default_pos
(
seg_ids
):
if
is_thd
:
batch_size
,
seq_size
=
seg_ids
.
shape
# Assume that the first token belongs to a segment and is not a padded token
first_is_segment
=
jnp
.
full
((
batch_size
,
1
),
True
,
dtype
=
bool
)
# Get segment start positions
segment_start
=
jnp
.
concatenate
(
[
first_is_segment
,
(
seg_ids
[...,
1
:]
!=
seg_ids
[...,
:
-
1
])
&
(
seg_ids
[...,
1
:]
!=
0
),
],
axis
=-
1
,
)
# Get offset for location where new segment starts
segment_start_idx
=
jax
.
vmap
(
lambda
row
:
jnp
.
arange
(
row
.
size
)
*
row
)(
segment_start
)
segment_start_offsets
=
jax
.
vmap
(
jnp
.
maximum
.
accumulate
)(
segment_start_idx
)
# Get the last non-zero index - after this everything is padding
# (B,)
last_nonzero_idx
=
jax
.
vmap
(
lambda
segids_row
:
jnp
.
max
(
jnp
.
where
(
segids_row
!=
0
,
jnp
.
arange
(
seq_size
),
-
1
)
)
)(
seg_ids
)
seg_pos_no_thd
=
jnp
.
arange
(
seq_size
)
# Get a mask which can be used to zero out all the padding at the end (after the non-zero index)
mask
=
seg_pos_no_thd
<=
last_nonzero_idx
[:,
None
]
# Get the unmasked seg_pos for the THD sequence
seg_pos
=
(
jnp
.
broadcast_to
(
jnp
.
arange
(
seq_size
),
seg_ids
.
shape
)
-
segment_start_offsets
)
# Use the mask to zero out the padding at the end (after the non-zero index)
segment_pos
=
jax
.
vmap
(
lambda
pos_row
,
mask_row
:
jnp
.
where
(
mask_row
,
pos_row
,
0
)
)(
seg_pos
,
mask
)
return
segment_pos
def
generate_default_pos
(
segment_ids
):
seqlen
=
seg_ids
.
shape
[
-
1
]
seqlen
=
segment_ids
.
shape
[
-
1
]
return
jnp
.
broadcast_to
(
jnp
.
arange
(
seqlen
),
seg_ids
.
shape
)
return
jnp
.
broadcast_to
(
jnp
.
arange
(
seqlen
),
segment_ids
.
shape
)
q_seg_pos
=
generate_default_pos
(
q_seg_ids
)
q_seg_pos
=
generate_default_pos
(
q_seg_ids
)
kv_seg_pos
=
generate_default_pos
(
kv_seg_ids
)
kv_seg_pos
=
generate_default_pos
(
kv_seg_ids
)
segment_pos
=
(
q_seg_pos
,
kv_seg_pos
)
segment_pos
=
(
q_seg_pos
,
kv_seg_pos
)
# Explicitly passed segment_pos
else
:
segment_pos
=
cls
.
_expand_to_pair
(
segment_pos
)
return
cls
(
return
cls
(
segment_ids
=
(
q_seg_ids
,
kv_seg_ids
),
segment_ids
=
(
q_seg_ids
,
kv_seg_ids
),
...
@@ -786,6 +910,7 @@ def _legacy_fused_attn(
...
@@ -786,6 +910,7 @@ def _legacy_fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -793,6 +918,7 @@ def _legacy_fused_attn(
...
@@ -793,6 +918,7 @@ def _legacy_fused_attn(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Perform non-THD (non-packed) cuDNN fused attention.
Perform non-THD (non-packed) cuDNN fused attention.
...
@@ -815,6 +941,7 @@ def _legacy_fused_attn(
...
@@ -815,6 +941,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -863,10 +990,12 @@ def _legacy_fused_attn(
...
@@ -863,10 +990,12 @@ def _legacy_fused_attn(
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -900,6 +1029,7 @@ def fused_attn_thd(
...
@@ -900,6 +1029,7 @@ def fused_attn_thd(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
...
@@ -937,6 +1067,7 @@ def fused_attn_thd(
...
@@ -937,6 +1067,7 @@ def fused_attn_thd(
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens_and_offsets
(
SequenceDescriptor
.
from_seqlens_and_offsets
(
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
),
),
...
@@ -945,6 +1076,7 @@ def fused_attn_thd(
...
@@ -945,6 +1076,7 @@ def fused_attn_thd(
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
softmax_type
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
max_segments_per_seq
=
max_segments_per_seq
,
max_segments_per_seq
=
max_segments_per_seq
,
...
@@ -957,15 +1089,17 @@ def fused_attn_thd(
...
@@ -957,15 +1089,17 @@ def fused_attn_thd(
return
output
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
))
def
_fused_attn
(
def
_fused_attn
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -975,15 +1109,18 @@ def _fused_attn(
...
@@ -975,15 +1109,18 @@ def _fused_attn(
context_parallel_causal_load_balanced
:
bool
,
context_parallel_causal_load_balanced
:
bool
,
context_parallel_axis
:
str
,
context_parallel_axis
:
str
,
context_checkpoint_name
:
str
=
"context"
,
context_checkpoint_name
:
str
=
"context"
,
stripe_size
:
int
|
None
=
None
,
):
):
output
,
_
=
_fused_attn_fwd_rule
(
output
,
_
=
_fused_attn_fwd_rule
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -993,6 +1130,7 @@ def _fused_attn(
...
@@ -993,6 +1130,7 @@ def _fused_attn(
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
,
context_parallel_axis
,
context_parallel_axis
,
context_checkpoint_name
=
context_checkpoint_name
,
context_checkpoint_name
=
context_checkpoint_name
,
stripe_size
=
stripe_size
,
)
)
return
output
return
output
...
@@ -1000,11 +1138,13 @@ def _fused_attn(
...
@@ -1000,11 +1138,13 @@ def _fused_attn(
def
_fused_attn_fwd_rule
(
def
_fused_attn_fwd_rule
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1014,14 +1154,17 @@ def _fused_attn_fwd_rule(
...
@@ -1014,14 +1154,17 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
,
context_parallel_axis
,
context_parallel_axis
,
context_checkpoint_name
,
context_checkpoint_name
,
stripe_size
,
):
):
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1031,6 +1174,7 @@ def _fused_attn_fwd_rule(
...
@@ -1031,6 +1174,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
context_parallel_axis
=
context_parallel_axis
,
stripe_size
=
stripe_size
,
)
)
output
=
checkpoint_name
(
output
,
context_checkpoint_name
)
output
=
checkpoint_name
(
output
,
context_checkpoint_name
)
softmax_aux
=
checkpoint_name
(
softmax_aux
,
context_checkpoint_name
)
softmax_aux
=
checkpoint_name
(
softmax_aux
,
context_checkpoint_name
)
...
@@ -1041,6 +1185,7 @@ def _fused_attn_fwd_rule(
...
@@ -1041,6 +1185,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor
,
sequence_descriptor
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
softmax_offset
,
output
,
output
,
)
)
...
@@ -1049,6 +1194,7 @@ def _fused_attn_bwd_rule(
...
@@ -1049,6 +1194,7 @@ def _fused_attn_bwd_rule(
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1058,6 +1204,7 @@ def _fused_attn_bwd_rule(
...
@@ -1058,6 +1204,7 @@ def _fused_attn_bwd_rule(
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
,
context_parallel_axis
,
context_parallel_axis
,
context_checkpoint_name
,
context_checkpoint_name
,
stripe_size
,
ctx
,
ctx
,
dz
,
dz
,
):
):
...
@@ -1068,11 +1215,13 @@ def _fused_attn_bwd_rule(
...
@@ -1068,11 +1215,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
sequence_descriptor
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
softmax_offset
,
output
,
output
,
)
=
ctx
)
=
ctx
grad_qkv
,
grad_bias
=
tex
.
fused_attn_bwd
(
grad_qkv
,
grad_bias
,
grad_softmax_offset
=
tex
.
fused_attn_bwd
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1080,6 +1229,7 @@ def _fused_attn_bwd_rule(
...
@@ -1080,6 +1229,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
sequence_descriptor
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1089,12 +1239,16 @@ def _fused_attn_bwd_rule(
...
@@ -1089,12 +1239,16 @@ def _fused_attn_bwd_rule(
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
context_parallel_axis
=
context_parallel_axis
,
stripe_size
=
stripe_size
,
)
)
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
grad_bias
=
None
grad_bias
=
None
if
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
grad_softmax_offset
=
None
return
(
return
(
grad_qkv
,
grad_qkv
,
grad_bias
,
grad_bias
,
grad_softmax_offset
,
None
,
None
,
None
,
None
,
)
)
...
@@ -1111,6 +1265,7 @@ def fused_attn(
...
@@ -1111,6 +1265,7 @@ def fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -1120,6 +1275,8 @@ def fused_attn(
...
@@ -1120,6 +1275,8 @@ def fused_attn(
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
context_checkpoint_name
:
str
=
"context"
,
context_checkpoint_name
:
str
=
"context"
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
stripe_size
:
int
|
None
=
None
,
):
):
"""
"""
Perform cuDNN fused attention.
Perform cuDNN fused attention.
...
@@ -1139,6 +1296,7 @@ def fused_attn(
...
@@ -1139,6 +1296,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -1153,6 +1311,14 @@ def fused_attn(
...
@@ -1153,6 +1311,14 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
stripe_size (int | None):
Indicates the striping size to be used when using ReorderStrategy.Striped.
Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1
is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring)
None indicates no striping strategy
Returns:
Returns:
(jnp.ndarray): The output tensor from the fused attention.
(jnp.ndarray): The output tensor from the fused attention.
...
@@ -1200,6 +1366,7 @@ def fused_attn(
...
@@ -1200,6 +1366,7 @@ def fused_attn(
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1208,15 +1375,18 @@ def fused_attn(
...
@@ -1208,15 +1375,18 @@ def fused_attn(
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
context_parallel_axis
=
context_parallel_axis
,
softmax_offset
=
softmax_offset
,
)
)
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
@@ -1226,5 +1396,6 @@ def fused_attn(
...
@@ -1226,5 +1396,6 @@ def fused_attn(
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
context_parallel_axis
=
context_parallel_axis
,
context_checkpoint_name
=
context_checkpoint_name
,
context_checkpoint_name
=
context_checkpoint_name
,
stripe_size
=
stripe_size
,
)
)
return
output
return
output
transformer_engine/jax/checkpoint_policies.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""Checkpoint policies for Transformer Engine in JAX.
"""Checkpoint policies for Transformer Engine in JAX.
...
...
transformer_engine/jax/cpp_extensions/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""Python interface for c++ extensions"""
"""Python interface for c++ extensions"""
...
...
transformer_engine/jax/cpp_extensions/activation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
"""JAX/TE custom ops for activation"""
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
QuantizeLayout
,
)
)
...
...
transformer_engine/jax/cpp_extensions/amax.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX/TE custom ops for amax calculation"""
"""JAX/TE custom ops for amax calculation"""
...
@@ -73,7 +73,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
...
@@ -73,7 +73,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
transpose_batch_sequence
,
transpose_batch_sequence
,
):
):
"""
"""
amax calcuation abstract
amax calcu
l
ation abstract
"""
"""
del
amax_scope
,
transpose_batch_sequence
del
amax_scope
,
transpose_batch_sequence
...
@@ -251,7 +251,7 @@ class RHTAmaxCalculationPrimitive(BasePrimitive):
...
@@ -251,7 +251,7 @@ class RHTAmaxCalculationPrimitive(BasePrimitive):
flatten_axis
,
flatten_axis
,
):
):
"""
"""
amax calcuation implementation
amax calcu
l
ation implementation
"""
"""
assert
RHTAmaxCalculationPrimitive
.
inner_primitive
is
not
None
assert
RHTAmaxCalculationPrimitive
.
inner_primitive
is
not
None
(
(
...
...
transformer_engine/jax/cpp_extensions/attention.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
"""JAX/TE custom ops for attention"""
...
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
...
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from
transformer_engine.jax.attention
import
(
from
transformer_engine.jax.attention
import
(
AttnBiasType
,
AttnBiasType
,
AttnMaskType
,
AttnMaskType
,
AttnSoftmaxType
,
QKVLayout
,
QKVLayout
,
QKVFormat
,
QKVFormat
,
CPStrategy
,
CPStrategy
,
SequenceDescriptor
,
SequenceDescriptor
,
)
)
from
..sharding
import
with_sharding_constraint_by_logical_axes
,
HEAD_AXES
,
is_mesh_available
from
.base
import
BasePrimitive
,
register_primitive
from
.base
import
BasePrimitive
,
register_primitive
from
.misc
import
(
from
.misc
import
(
...
@@ -61,6 +63,7 @@ __all__ = [
...
@@ -61,6 +63,7 @@ __all__ = [
meta_fields
=
[
meta_fields
=
[
"attn_bias_type"
,
"attn_bias_type"
,
"attn_mask_type"
,
"attn_mask_type"
,
"softmax_type"
,
"qkv_layout"
,
"qkv_layout"
,
"scaling_factor"
,
"scaling_factor"
,
"dropout_probability"
,
"dropout_probability"
,
...
@@ -70,6 +73,7 @@ __all__ = [
...
@@ -70,6 +73,7 @@ __all__ = [
"context_parallel_load_balanced"
,
"context_parallel_load_balanced"
,
"cp_axis"
,
"cp_axis"
,
"cp_striped_window_size"
,
"cp_striped_window_size"
,
"stripe_size"
,
],
],
)
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -80,6 +84,7 @@ class _FusedAttnConfig:
...
@@ -80,6 +84,7 @@ class _FusedAttnConfig:
attn_bias_type
:
AttnBiasType
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
qkv_layout
:
QKVLayout
qkv_layout
:
QKVLayout
scaling_factor
:
float
scaling_factor
:
float
dropout_probability
:
float
dropout_probability
:
float
...
@@ -88,7 +93,10 @@ class _FusedAttnConfig:
...
@@ -88,7 +93,10 @@ class _FusedAttnConfig:
window_size
:
Tuple
[
int
,
int
]
window_size
:
Tuple
[
int
,
int
]
context_parallel_load_balanced
:
bool
context_parallel_load_balanced
:
bool
cp_axis
:
str
cp_axis
:
str
cp_striped_window_size
:
Tuple
[
int
,
int
]
# Only for CP + Ring + THD + SWA
cp_striped_window_size
:
Tuple
[
int
,
int
]
# Only for CP + Ring P2P + THD + SWA
stripe_size
:
(
int
|
None
)
# Only for CP + Striped. For Ring P2P, stripe_size=1 only.For AG, stripe_size>=1.
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -103,6 +111,7 @@ class FusedAttnHelper:
...
@@ -103,6 +111,7 @@ class FusedAttnHelper:
qkv_layout
:
QKVLayout
qkv_layout
:
QKVLayout
attn_bias_type
:
AttnBiasType
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
dropout_probability
:
float
dropout_probability
:
float
q_num_heads
:
int
q_num_heads
:
int
kv_num_heads
:
int
kv_num_heads
:
int
...
@@ -125,6 +134,7 @@ class FusedAttnHelper:
...
@@ -125,6 +134,7 @@ class FusedAttnHelper:
self
.
qkv_layout
.
value
,
self
.
qkv_layout
.
value
,
self
.
attn_bias_type
.
value
,
self
.
attn_bias_type
.
value
,
self
.
attn_mask_type
.
value
,
self
.
attn_mask_type
.
value
,
self
.
softmax_type
.
value
,
self
.
dropout_probability
,
self
.
dropout_probability
,
self
.
q_num_heads
,
self
.
q_num_heads
,
self
.
kv_num_heads
,
self
.
kv_num_heads
,
...
@@ -254,7 +264,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -254,7 +264,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name
=
"te_fused_attn_forward_ffi"
name
=
"te_fused_attn_forward_ffi"
multiple_results
=
True
multiple_results
=
True
impl_static_args
=
(
1
3
,)
impl_static_args
=
(
1
4
,)
inner_primitive
=
None
inner_primitive
=
None
outer_primitive
=
None
outer_primitive
=
None
...
@@ -264,6 +274,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -264,6 +274,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_aval
,
k_aval
,
v_aval
,
v_aval
,
bias_aval
,
bias_aval
,
softmax_offset_aval
,
seed_aval
,
seed_aval
,
q_seqlen_or_cu_seqlen_aval
,
q_seqlen_or_cu_seqlen_aval
,
kv_seqlen_or_cu_seqlen_aval
,
kv_seqlen_or_cu_seqlen_aval
,
...
@@ -312,6 +323,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -312,6 +323,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
qkv_layout
,
config
.
qkv_layout
,
config
.
attn_bias_type
,
config
.
attn_bias_type
,
config
.
attn_mask_type
,
config
.
attn_mask_type
,
config
.
softmax_type
,
config
.
dropout_probability
,
config
.
dropout_probability
,
attn_heads
,
attn_heads
,
num_gqa_groups
,
num_gqa_groups
,
...
@@ -375,6 +387,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -375,6 +387,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
dropout_probability
,
config
.
dropout_probability
,
config
.
attn_bias_type
.
value
,
config
.
attn_bias_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
softmax_type
.
value
,
config
.
qkv_layout
.
value
,
config
.
qkv_layout
.
value
,
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
config
.
is_training
,
config
.
is_training
,
...
@@ -386,6 +399,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -386,6 +399,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
)
assert
softmax_offset_aval
.
dtype
==
jnp
.
float32
if
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
assert
softmax_offset_aval
.
shape
==
(
1
,
attn_heads
,
1
,
1
)
else
:
assert
softmax_offset_aval
.
shape
==
(
0
,)
return
out_aval
,
softmax_aux_aval
,
rng_state_aval
,
wkspace_aval
return
out_aval
,
softmax_aux_aval
,
rng_state_aval
,
wkspace_aval
@
staticmethod
@
staticmethod
...
@@ -405,6 +424,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -405,6 +424,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_cu_seqlen
,
q_cu_seqlen
,
kv_cu_seqlen
,
kv_cu_seqlen
,
...
@@ -453,6 +473,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -453,6 +473,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_cu_seqlen
,
q_cu_seqlen
,
kv_cu_seqlen
,
kv_cu_seqlen
,
...
@@ -481,6 +502,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -481,6 +502,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
window_size_left
=
window_size_left
,
window_size_left
=
window_size_left
,
window_size_right
=
window_size_right
,
window_size_right
=
window_size_right
,
softmax_type
=
int
(
config
.
softmax_type
.
value
),
)
)
@
staticmethod
@
staticmethod
...
@@ -489,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -489,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -508,7 +531,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -508,7 +531,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
segment_ids
=
(
_q_segment_ids
,
_kv_segment_ids
),
segment_ids
=
(
_q_segment_ids
,
_kv_segment_ids
),
segment_pos
=
(
_q_segment_pos
,
_kv_segment_pos
),
segment_pos
=
(
_q_segment_pos
,
_kv_segment_pos
),
)
)
(
q_seqlen
,
kv_seqlen
),
(
q_seq_offsets
,
k_seq_offsets
)
=
(
(
q_seqlen
,
kv_seqlen
),
(
q_seq_offsets
,
k_seq_offsets
)
=
(
sequence_descriptor
.
get_seqlens_and_offsets
(
sequence_descriptor
.
get_seqlens_and_offsets
(
config
.
attn_mask_type
,
config
.
attn_mask_type
,
...
@@ -517,7 +539,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -517,7 +539,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
max_segments_per_seq
,
config
.
max_segments_per_seq
,
)
)
)
)
if
config
.
qkv_layout
.
is_thd
():
if
config
.
qkv_layout
.
is_thd
():
def
_fix_len_take
(
x
,
condition
,
fill_value
=-
1
):
def
_fix_len_take
(
x
,
condition
,
fill_value
=-
1
):
...
@@ -579,6 +600,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -579,6 +600,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_cu_seqlen
,
q_cu_seqlen
,
kv_cu_seqlen
,
kv_cu_seqlen
,
...
@@ -596,7 +618,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -596,7 +618,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
check_valid_batch_dims
(
batch_dims
)
check_valid_batch_dims
(
batch_dims
)
assert
FusedAttnFwdPrimitive
.
outer_primitive
is
not
None
assert
FusedAttnFwdPrimitive
.
outer_primitive
is
not
None
q_bdim
,
_
,
_
,
_
,
seed_bdim
,
*
_
=
batch_dims
q_bdim
,
_
,
_
,
_
,
_
,
seed_bdim
,
*
_
=
batch_dims
out_bdims
=
q_bdim
,
q_bdim
,
seed_bdim
out_bdims
=
q_bdim
,
q_bdim
,
seed_bdim
return
(
return
(
...
@@ -662,7 +684,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -662,7 +684,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
...
@@ -710,7 +732,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -710,7 +732,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name
=
"te_fused_attn_backward_ffi"
name
=
"te_fused_attn_backward_ffi"
multiple_results
=
True
multiple_results
=
True
impl_static_args
=
(
1
6
,)
impl_static_args
=
(
1
7
,)
inner_primitive
=
None
inner_primitive
=
None
outer_primitive
=
None
outer_primitive
=
None
...
@@ -720,6 +742,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -720,6 +742,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_aval
,
k_aval
,
v_aval
,
v_aval
,
bias_aval
,
bias_aval
,
softmax_offset_aval
,
softmax_aux_aval
,
softmax_aux_aval
,
rng_state_aval
,
rng_state_aval
,
output_aval
,
output_aval
,
...
@@ -781,6 +804,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -781,6 +804,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config
.
dropout_probability
,
config
.
dropout_probability
,
config
.
attn_bias_type
.
value
,
config
.
attn_bias_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
softmax_type
.
value
,
config
.
qkv_layout
.
value
,
config
.
qkv_layout
.
value
,
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
config
.
is_training
,
config
.
is_training
,
...
@@ -798,15 +822,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -798,15 +822,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
shape
=
wkspace_shape
,
dtype
=
te_dtype_to_jax_dtype
(
wkspace_dtype
)
shape
=
wkspace_shape
,
dtype
=
te_dtype_to_jax_dtype
(
wkspace_dtype
)
)
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
wkspace_aval
# Validate incoming softmax_offset shape and dtype
assert
(
softmax_offset_aval
.
dtype
==
jnp
.
float32
),
f
"Incorrect softmax_offset dtype:
{
softmax_offset_aval
.
dtype
}
, expected:
{
jnp
.
float32
}
"
if
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
assert
softmax_offset_aval
.
shape
==
(
1
,
attn_heads
,
1
,
1
),
(
f
"Incorrect softmax_offset shape for
{
config
.
softmax_type
}
:"
f
"
{
softmax_offset_aval
.
shape
}
, expected: (1,
{
attn_heads
}
, 1, 1)"
)
else
:
assert
softmax_offset_aval
.
shape
==
(
0
,),
(
f
"Incorrect softmax_offset shape for
{
config
.
softmax_type
}
:"
f
"
{
softmax_offset_aval
.
shape
}
, expected: (0,)"
)
if
config
.
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
dsoftmax_offset_aval
=
q_aval
.
update
(
shape
=
softmax_offset_aval
.
shape
,
dtype
=
softmax_offset_aval
.
dtype
)
else
:
dsoftmax_offset_aval
=
q_aval
.
update
(
shape
=
(
1
,
attn_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
,
wkspace_aval
@
staticmethod
@
staticmethod
def
outer_abstract
(
*
args
,
**
kwargs
):
def
outer_abstract
(
*
args
,
**
kwargs
):
"""
"""
Fused attention fwd outer primitive abstract
Fused attention fwd outer primitive abstract
"""
"""
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
_
=
FusedAttnBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
,
_
=
(
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
FusedAttnBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
@
staticmethod
@
staticmethod
def
lowering
(
def
lowering
(
...
@@ -815,6 +863,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -815,6 +863,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -866,6 +915,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -866,6 +915,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -897,6 +947,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -897,6 +947,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
window_size_left
=
window_size_left
,
window_size_left
=
window_size_left
,
window_size_right
=
window_size_right
,
window_size_right
=
window_size_right
,
softmax_type
=
int
(
config
.
softmax_type
.
value
),
)
)
@
staticmethod
@
staticmethod
...
@@ -905,6 +956,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -905,6 +956,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -993,11 +1045,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -993,11 +1045,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen
=
generate_cu_seqlen
(
q_seqlen
.
flatten
())
q_cu_seqlen
=
generate_cu_seqlen
(
q_seqlen
.
flatten
())
kv_cu_seqlen
=
generate_cu_seqlen
(
kv_seqlen
.
flatten
())
kv_cu_seqlen
=
generate_cu_seqlen
(
kv_seqlen
.
flatten
())
dq
,
dk
,
dv
,
dbias
,
_
=
FusedAttnBwdPrimitive
.
inner_primitive
.
bind
(
dq
,
dk
,
dv
,
dbias
,
dsoftmax_offset
,
_
=
FusedAttnBwdPrimitive
.
inner_primitive
.
bind
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1012,15 +1065,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1012,15 +1065,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos
,
_kv_segment_pos
,
config
=
config
,
config
=
config
,
)
)
return
dq
,
dk
,
dv
,
dbias
return
dq
,
dk
,
dv
,
dbias
,
dsoftmax_offset
@
staticmethod
@
staticmethod
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
check_valid_batch_dims
(
batch_dims
)
check_valid_batch_dims
(
batch_dims
)
assert
FusedAttnBwdPrimitive
.
outer_primitive
is
not
None
assert
FusedAttnBwdPrimitive
.
outer_primitive
is
not
None
q_bdim
,
k_bdim
,
v_bdim
,
*
_
=
batch_dims
q_bdim
,
k_bdim
,
v_bdim
,
bias_bdim
,
softmax_offset_bdim
,
*
_
=
batch_dims
out_bdims
=
q_bdim
,
k_bdim
,
v_bdim
,
q
_bdim
out_bdims
=
q_bdim
,
k_bdim
,
v_bdim
,
bias_bdim
,
softmax_offset
_bdim
return
(
return
(
FusedAttnBwdPrimitive
.
outer_primitive
.
bind
(
*
batched_args
,
config
=
config
),
FusedAttnBwdPrimitive
.
outer_primitive
.
bind
(
*
batched_args
,
config
=
config
),
out_bdims
,
out_bdims
,
...
@@ -1033,11 +1086,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1033,11 +1086,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
return
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
return
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
)
@
staticmethod
@
staticmethod
def
partition
(
config
,
mesh
,
arg_infos
,
result_infos
):
def
partition
(
config
,
mesh
,
arg_infos
,
result_infos
):
...
@@ -1046,21 +1101,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1046,21 +1101,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
def
sharded_impl
(
def
sharded_impl
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1074,11 +1138,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1074,11 +1138,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_q_segment_pos
,
_q_segment_pos
,
_kv_segment_pos
,
_kv_segment_pos
,
):
):
local_dq
,
local_dk
,
local_dv
,
local_dbias
=
FusedAttnBwdPrimitive
.
impl
(
local_dq
,
local_dk
,
local_dv
,
local_dbias
,
local_dsoftmax_offset
=
(
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1093,17 +1159,22 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1093,17 +1159,22 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos
,
_kv_segment_pos
,
config
=
config
,
config
=
config
,
)
)
)
global_dbias
=
local_dbias
global_dbias
=
local_dbias
if
config
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
:
if
config
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
:
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
local_dbias
,
mesh
)
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
local_dbias
,
mesh
)
return
local_dq
,
local_dk
,
local_dv
,
global_dbias
global_dsoftmax_offset
=
local_dsoftmax_offset
if
config
.
softmax_type
==
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
global_dsoftmax_offset
=
all_reduce_sum_along_dp_fsdp
(
local_dsoftmax_offset
,
mesh
)
return
local_dq
,
local_dk
,
local_dv
,
global_dbias
,
global_dsoftmax_offset
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
@
staticmethod
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
del
config
,
mesh
del
config
,
mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
# Keep in sync with `infer_sharding_from_operands`.
input_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
)))
input_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
)))
output_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
result_types
)))
output_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
result_types
)))
...
@@ -1165,31 +1236,38 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig
...
@@ -1165,31 +1236,38 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig
return
combined
.
reshape
(
ori_tensor_shape
)
return
combined
.
reshape
(
ori_tensor_shape
)
def
reorder_causal_striped
(
tensor
,
cp_size
:
int
,
seq_dim
:
int
,
is_inverse
:
bool
):
def
reorder_causal_striped
(
tensor
,
cp_size
:
int
,
seq_dim
:
int
,
is_inverse
:
bool
,
stripe_size
:
int
=
1
):
"""Reorders a tensor for load balancing with striped pattern"""
"""Reorders a tensor for load balancing with striped pattern"""
origin_shape
=
tensor
.
shape
origin_shape
=
tensor
.
shape
if
origin_shape
[
seq_dim
]
%
cp
_size
!
=
0
:
if
stripe
_size
<
=
0
:
raise
ValueError
(
raise
ValueError
(
"Expected origin_shape[seq_dim] is multiple of cp_size but got"
f
"Incorrect value for CP reordering
{
stripe_size
=
}
. stripe_size must be a positive"
f
"
{
origin_shape
[
seq_dim
]
=
}
and
{
cp_size
=
}
"
" integer"
)
if
origin_shape
[
seq_dim
]
%
(
cp_size
*
stripe_size
)
!=
0
:
raise
ValueError
(
"Expected origin_shape[seq_dim] is multiple of cp_size*stripe_size but got"
f
"
{
origin_shape
[
seq_dim
]
=
}
,
{
cp_size
=
}
,
{
stripe_size
=
}
,
{
cp_size
*
stripe_size
=
}
"
)
)
if
not
is_inverse
:
if
not
is_inverse
:
new_shape
=
[
new_shape
=
[
*
origin_shape
[:
seq_dim
],
*
origin_shape
[:
seq_dim
],
*
[
origin_shape
[
seq_dim
]
//
cp_size
,
cp
_size
],
*
[
origin_shape
[
seq_dim
]
//
(
cp_size
*
stripe_size
),
cp_size
,
stripe
_size
],
*
origin_shape
[
seq_dim
+
1
:],
*
origin_shape
[
seq_dim
+
1
:],
]
]
else
:
else
:
new_shape
=
[
new_shape
=
[
*
origin_shape
[:
seq_dim
],
*
origin_shape
[:
seq_dim
],
*
[
cp_size
,
origin_shape
[
seq_dim
]
//
cp_size
],
*
[
cp_size
,
origin_shape
[
seq_dim
]
//
(
cp_size
*
stripe_size
),
stripe_size
],
*
origin_shape
[
seq_dim
+
1
:],
*
origin_shape
[
seq_dim
+
1
:],
]
]
chunk
ed_tensor
=
tensor
.
reshape
(
new_shape
)
strip
ed_tensor
=
tensor
.
reshape
(
new_shape
)
reordered_
chunk
ed_tensor
=
jnp
.
swapaxes
(
chunk
ed_tensor
,
seq_dim
,
seq_dim
+
1
)
reordered_
strip
ed_tensor
=
jnp
.
swapaxes
(
strip
ed_tensor
,
seq_dim
,
seq_dim
+
1
)
return
reordered_
chunk
ed_tensor
.
reshape
(
origin_shape
)
return
reordered_
strip
ed_tensor
.
reshape
(
origin_shape
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -1203,43 +1281,85 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1203,43 +1281,85 @@ class _FusedAttnCPWithAllGatherHelper:
"""Checks if the context parallel implementation is supported by the given arguments."""
"""Checks if the context parallel implementation is supported by the given arguments."""
header
=
"Context parallel fused attention"
header
=
"Context parallel fused attention"
allowed_layouts
=
[
QKVLayout
.
BSHD_BS2HD
,
QKVLayout
.
BSHD_BSHD_BSHD
]
allowed_layouts
=
[
QKVLayout
.
BSHD_BS2HD
,
QKVLayout
.
BSHD_BSHD_BSHD
,
QKVLayout
.
THD_T2HD
,
QKVLayout
.
THD_THD_THD
,
]
if
self
.
config
.
qkv_layout
not
in
allowed_layouts
:
if
self
.
config
.
qkv_layout
not
in
allowed_layouts
:
raise
ValueError
(
raise
ValueError
(
f
"
{
header
}
only supports layouts:"
f
"
{
header
}
only supports layouts:"
f
"
{
','
.
join
(
map
(
str
,
allowed_layouts
))
}
got:
{
self
.
config
.
qkv_layout
}
"
f
"
{
','
.
join
(
map
(
str
,
allowed_layouts
))
}
got:
{
self
.
config
.
qkv_layout
}
"
)
)
if
(
not
self
.
config
.
qkv_layout
.
is_thd
()
and
self
.
config
.
stripe_size
is
not
None
)
or
(
self
.
config
.
qkv_layout
.
is_thd
()
and
self
.
config
.
stripe_size
is
None
):
raise
ValueError
(
f
"
{
header
}
only supports Dual Chunk load balancing with BSHD layouts and Striped"
" load balancing with THD layouts"
)
if
self
.
config
.
attn_bias_type
!=
AttnBiasType
.
NO_BIAS
:
if
self
.
config
.
attn_bias_type
!=
AttnBiasType
.
NO_BIAS
:
raise
ValueError
(
f
"
{
header
}
does not support bias got:
{
self
.
config
.
attn_bias_type
}
"
)
raise
ValueError
(
f
"
{
header
}
does not support bias got:
{
self
.
config
.
attn_bias_type
}
"
)
allowed_masks
=
[
AttnMaskType
.
NO_MASK
,
AttnMaskType
.
CAUSAL_MASK
]
allowed_masks
=
[
AttnMaskType
.
NO_MASK
,
AttnMaskType
.
CAUSAL_MASK
]
if
self
.
config
.
qkv_layout
.
is_thd
():
allowed_masks
.
append
(
AttnMaskType
.
PADDING_CAUSAL_MASK
)
if
self
.
config
.
attn_mask_type
not
in
allowed_masks
:
if
self
.
config
.
attn_mask_type
not
in
allowed_masks
:
raise
ValueError
(
raise
ValueError
(
f
"
{
header
}
only supports masking types: "
f
"
{
header
}
only supports masking types: "
f
"
{
','
.
join
(
map
(
str
,
allowed_masks
))
}
got:
{
self
.
config
.
attn_mask_type
}
"
f
"
{
','
.
join
(
map
(
str
,
allowed_masks
))
}
got:
{
self
.
config
.
attn_mask_type
}
"
)
)
# Do not allow CP + AG + THD + Striped with NO_MASK
if
(
self
.
config
.
attn_mask_type
is
not
AttnMaskType
.
PADDING_CAUSAL_MASK
and
self
.
config
.
qkv_layout
.
is_thd
()
):
raise
ValueError
(
f
"
{
header
}
only supports PADDING_CAUSAL_MASK for THD types"
)
if
self
.
config
.
max_segments_per_seq
!=
1
:
if
self
.
config
.
max_segments_per_seq
!=
1
and
(
not
self
.
config
.
qkv_layout
.
is_thd
)
:
raise
ValueError
(
raise
ValueError
(
f
"
{
header
}
only supports max_segments_per_seq == 1 got:"
f
"
{
header
}
only supports max_segments_per_seq == 1
for BSHD layouts,
got:"
f
"
{
self
.
config
.
max_segments_per_seq
}
"
f
"
{
self
.
config
.
max_segments_per_seq
}
"
)
)
if
self
.
config
.
dropout_probability
!=
0.0
:
if
self
.
config
.
dropout_probability
!=
0.0
:
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
if
self
.
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
raise
ValueError
(
f
"
{
header
}
only supports VANILLA_SOFTMAX, got:
{
self
.
config
.
softmax_type
}
"
)
def
get_adjusted_mask
(
self
):
def
get_adjusted_mask
(
self
):
"""Converts the mask for context parallelism."""
"""Converts the mask for context parallelism."""
if
self
.
config
.
attn_mask_type
==
AttnMaskType
.
CAUSAL_MASK
:
if
(
self
.
config
.
attn_mask_type
==
AttnMaskType
.
CAUSAL_MASK
and
not
self
.
config
.
qkv_layout
.
is_thd
()
):
# BSHD AG case only
return
AttnMaskType
.
CAUSAL_BOTTOM_RIGHT_MASK
return
AttnMaskType
.
CAUSAL_BOTTOM_RIGHT_MASK
if
(
self
.
config
.
attn_mask_type
==
AttnMaskType
.
PADDING_CAUSAL_MASK
and
self
.
config
.
qkv_layout
.
is_thd
()
):
# THD AG case only
return
AttnMaskType
.
PADDING_CAUSAL_BOTTOM_RIGHT_MASK
return
self
.
config
.
attn_mask_type
return
self
.
config
.
attn_mask_type
def
get_adjusted_max_segments_per_seq
(
self
,
max_seqlen
,
cp_size
):
"""Converts the max segments per seq for context parallelism AG + THD."""
# Estimating adjusted max segments per seq
return
(
max_seqlen
//
(
self
.
config
.
stripe_size
*
cp_size
)
)
+
self
.
config
.
max_segments_per_seq
def
get_step_config
(
self
)
->
_FusedAttnConfig
:
def
get_step_config
(
self
)
->
_FusedAttnConfig
:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
return
_FusedAttnConfig
(
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
self
.
get_adjusted_mask
(),
attn_mask_type
=
self
.
get_adjusted_mask
(),
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
self
.
config
.
qkv_layout
,
qkv_layout
=
self
.
config
.
qkv_layout
,
scaling_factor
=
self
.
config
.
scaling_factor
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
dropout_probability
=
self
.
config
.
dropout_probability
,
...
@@ -1249,10 +1369,29 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1249,10 +1369,29 @@ class _FusedAttnCPWithAllGatherHelper:
context_parallel_load_balanced
=
self
.
config
.
context_parallel_load_balanced
,
context_parallel_load_balanced
=
self
.
config
.
context_parallel_load_balanced
,
cp_axis
=
self
.
config
.
cp_axis
,
cp_axis
=
self
.
config
.
cp_axis
,
cp_striped_window_size
=
None
,
cp_striped_window_size
=
None
,
stripe_size
=
self
.
config
.
stripe_size
,
)
def
get_step_config_for_striped
(
self
,
max_seqlen
,
cp_size
)
->
_FusedAttnConfig
:
"""Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
self
.
get_adjusted_mask
(),
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
self
.
config
.
qkv_layout
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
is_training
=
self
.
config
.
is_training
,
max_segments_per_seq
=
self
.
get_adjusted_max_segments_per_seq
(
max_seqlen
,
cp_size
),
window_size
=
self
.
config
.
window_size
,
context_parallel_load_balanced
=
self
.
config
.
context_parallel_load_balanced
,
cp_axis
=
self
.
config
.
cp_axis
,
cp_striped_window_size
=
None
,
stripe_size
=
self
.
config
.
stripe_size
,
)
)
def
all_gather_kv
(
self
,
k
,
v
):
def
all_gather_kv
(
self
,
k
,
v
):
"""Performs a all-gather of k and v over context parallel ranks."""
"""Performs a
n
all-gather of k and v over context parallel ranks."""
def
ag
(
x
):
def
ag
(
x
):
x
=
lax_paral_op
(
x
=
lax_paral_op
(
...
@@ -1260,6 +1399,9 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1260,6 +1399,9 @@ class _FusedAttnCPWithAllGatherHelper:
)
)
if
self
.
config
.
context_parallel_load_balanced
:
if
self
.
config
.
context_parallel_load_balanced
:
cp_size
=
get_mesh_axis_size
(
self
.
config
.
cp_axis
,
self
.
mesh
)
cp_size
=
get_mesh_axis_size
(
self
.
config
.
cp_axis
,
self
.
mesh
)
if
self
.
config
.
qkv_layout
.
is_thd
():
x
=
reorder_causal_striped
(
x
,
cp_size
,
1
,
True
,
self
.
config
.
stripe_size
)
else
:
x
=
reorder_causal_dual_chunk_swap
(
x
,
cp_size
,
1
,
to_contiguous
=
True
)
x
=
reorder_causal_dual_chunk_swap
(
x
,
cp_size
,
1
,
to_contiguous
=
True
)
return
x
return
x
...
@@ -1270,12 +1412,35 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1270,12 +1412,35 @@ class _FusedAttnCPWithAllGatherHelper:
return
k
,
v
# fall through
return
k
,
v
# fall through
def
all_gather_segment_ids_and_pos
(
self
,
kv_segment_ids
,
kv_segment_pos
):
"""Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks."""
kv_segment_ids
=
lax_paral_op
(
kv_segment_ids
,
lax
.
all_gather
,
self
.
config
.
cp_axis
,
mesh
=
self
.
mesh
,
axis
=
1
,
tiled
=
True
)
kv_segment_pos
=
lax_paral_op
(
kv_segment_pos
,
lax
.
all_gather
,
self
.
config
.
cp_axis
,
mesh
=
self
.
mesh
,
axis
=
1
,
tiled
=
True
)
if
self
.
config
.
context_parallel_load_balanced
:
cp_size
=
get_mesh_axis_size
(
self
.
config
.
cp_axis
,
self
.
mesh
)
if
self
.
config
.
qkv_layout
.
is_thd
():
kv_segment_ids_ag
=
reorder_causal_striped
(
kv_segment_ids
,
cp_size
,
1
,
True
,
self
.
config
.
stripe_size
)
kv_segment_pos_ag
=
reorder_causal_striped
(
kv_segment_pos
,
cp_size
,
1
,
True
,
self
.
config
.
stripe_size
)
return
kv_segment_ids_ag
,
kv_segment_pos_ag
return
kv_segment_ids
,
kv_segment_pos
# fall through
def
reduce_scatter_dkv
(
self
,
dk
,
dv
):
def
reduce_scatter_dkv
(
self
,
dk
,
dv
):
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""
def
rs
(
x
):
def
rs
(
x
):
if
self
.
config
.
context_parallel_load_balanced
:
if
self
.
config
.
context_parallel_load_balanced
:
cp_size
=
get_mesh_axis_size
(
self
.
config
.
cp_axis
,
self
.
mesh
)
cp_size
=
get_mesh_axis_size
(
self
.
config
.
cp_axis
,
self
.
mesh
)
if
self
.
config
.
qkv_layout
.
is_thd
():
x
=
reorder_causal_striped
(
x
,
cp_size
,
1
,
False
,
self
.
config
.
stripe_size
)
else
:
x
=
reorder_causal_dual_chunk_swap
(
x
,
cp_size
,
1
,
to_contiguous
=
False
)
x
=
reorder_causal_dual_chunk_swap
(
x
,
cp_size
,
1
,
to_contiguous
=
False
)
return
lax_paral_op
(
return
lax_paral_op
(
...
@@ -1349,6 +1514,227 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1349,6 +1514,227 @@ class _FusedAttnCPWithAllGatherHelper:
return
dk
,
dv
# fall through
return
dk
,
dv
# fall through
# Below are the sharded post AG q seg ids and pos for a given rank:
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
# segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]]
# seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]
# seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]]
def
q_seqlens_for_striped_for_rank
(
self
,
q_segment_ids
,
q_segment_pos
,
max_segments_per_seq
):
"""Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos"""
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask
=
q_segment_ids
!=
0
max_size
=
q_segment_ids
.
shape
[
-
1
]
non_zero_indices
=
jax
.
vmap
(
lambda
mask_row
:
jnp
.
where
(
mask_row
,
size
=
max_size
,
fill_value
=-
1
)[
0
]
)(
non_zero_mask
)
# Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos
# Clip -1 to 0 for safe indexing
clipped_indices
=
jnp
.
clip
(
non_zero_indices
,
0
,
None
)
valid_segment_ids
=
jnp
.
where
(
non_zero_indices
>=
0
,
jnp
.
take_along_axis
(
q_segment_ids
,
clipped_indices
,
axis
=-
1
),
0
)
valid_segment_pos
=
jnp
.
where
(
non_zero_indices
>=
0
,
jnp
.
take_along_axis
(
q_segment_pos
,
clipped_indices
,
axis
=-
1
),
0
)
# Create a mask for actual valid entries (not padding)
actual_valid
=
valid_segment_ids
!=
0
# First element is True only if it's actually valid
first_is_segment
=
actual_valid
[...,
0
:
1
]
# Detect segment breaks in the valid tokens only (not full seq)
# Padding will always be true as the segment change condition is being applied
# on the valid segments (which have padding at the end so they'll always trigger True)
segment_changes
=
jnp
.
concatenate
(
[
first_is_segment
,
# First valid element starts a segment
(
valid_segment_ids
[...,
1
:]
!=
valid_segment_ids
[...,
:
-
1
])
|
(
valid_segment_pos
[...,
1
:]
!=
valid_segment_pos
[...,
:
-
1
]
+
1
),
],
axis
=-
1
,
)
new_segment_ids
=
jnp
.
cumsum
(
segment_changes
,
axis
=-
1
)
seqlens_pre
=
jax
.
vmap
(
lambda
av_row
,
nsi_row
:
jnp
.
where
(
av_row
,
nsi_row
,
0
).
astype
(
jnp
.
int32
)
)(
actual_valid
,
new_segment_ids
)
seqlens_all
=
jax
.
vmap
(
lambda
sp_row
:
jnp
.
bincount
(
sp_row
,
length
=
max_segments_per_seq
+
1
)[
1
:]
)(
seqlens_pre
)
seqlens_all_pad_neg
=
jnp
.
where
(
seqlens_all
==
0
,
-
1
,
seqlens_all
)
return
seqlens_all_pad_neg
# Below are the sharded post AG q seg ids and pos for a given rank:
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, False, False, False]]
# segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]
# seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]]
def
q_seqoffsets_for_striped_for_rank
(
self
,
q_segment_ids
,
q_segment_pos
,
max_segments_per_seq
):
"""Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos"""
segment_changes
=
jnp
.
concatenate
(
[
jnp
.
full
(
(
q_segment_pos
.
shape
[
0
],
1
),
True
,
dtype
=
bool
),
# First valid element starts a segment
(
q_segment_pos
[...,
1
:]
!=
q_segment_pos
[...,
:
-
1
]
+
1
),
# Segment pos changed
],
axis
=-
1
,
)
# Remove any padded region segment changes
segment_changes_masked
=
jnp
.
where
(
q_segment_ids
!=
0
,
segment_changes
,
False
)
# Get the indices for segment changes (these are the offsets)
seq_offsets
=
jax
.
vmap
(
lambda
scm_row
:
jnp
.
where
(
scm_row
,
size
=
max_segments_per_seq
,
fill_value
=-
1
)[
0
]
)(
segment_changes_masked
)
return
seq_offsets
# Below are the sharded post AG q seg ids and pos for a given rank:
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# non_zero_mask = [[ True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True]]
# non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
# segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]]
# selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]]
def
kv_seqlens_for_striped_for_rank
(
self
,
kv_segment_ids
,
kv_segment_pos
,
max_segments_per_seq
):
"""Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos"""
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask
=
kv_segment_ids
!=
0
max_size
=
kv_segment_ids
.
shape
[
-
1
]
non_zero_indices
=
jax
.
vmap
(
lambda
mask_row
:
jnp
.
where
(
mask_row
,
size
=
max_size
,
fill_value
=-
1
)[
0
]
)(
non_zero_mask
)
# Pick non zero seg ids and seg pos using take_along_axis
# Clip -1 to 0 for safe indexing
clipped_indices
=
jnp
.
clip
(
non_zero_indices
,
0
,
None
)
valid_segment_ids
=
jnp
.
where
(
non_zero_indices
>=
0
,
jnp
.
take_along_axis
(
kv_segment_ids
,
clipped_indices
,
axis
=-
1
),
0
)
valid_segment_pos
=
jnp
.
where
(
non_zero_indices
>=
0
,
jnp
.
take_along_axis
(
kv_segment_pos
,
clipped_indices
,
axis
=-
1
),
0
)
actual_valid
=
valid_segment_ids
!=
0
# Detect segment breaks (only for non-zero segments)
segment_changes
=
jnp
.
concatenate
(
[
(
(
valid_segment_ids
[...,
1
:]
!=
valid_segment_ids
[...,
:
-
1
])
&
actual_valid
[...,
1
:]
)
|
(
valid_segment_pos
[...,
1
:]
!=
valid_segment_pos
[...,
:
-
1
]
+
1
),
actual_valid
[...,
-
1
:],
],
axis
=-
1
,
)
# Get the indices for segment changes
segment_changes_valid
=
jax
.
vmap
(
lambda
sc_row
,
av_row
:
jnp
.
where
(
sc_row
&
av_row
,
size
=
max_segments_per_seq
,
fill_value
=-
1
)[
0
]
)(
segment_changes
,
actual_valid
)
safe_indices
=
jnp
.
maximum
(
segment_changes_valid
,
0
)
# Select values using take_along_axis per row
selected_values
=
jnp
.
where
(
segment_changes_valid
>=
0
,
jnp
.
take_along_axis
(
valid_segment_pos
,
safe_indices
,
axis
=-
1
)
+
1
,
-
1
,
)
return
selected_values
# Below are the sharded post AG q seg ids and pos for a given rank:
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
# 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
# kv_segment_pos_ag = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
# 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
# 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# segment_changes_first_true_masked = [[ True, False, False, False, False, False, False, False, True,
# False, False, False, True, False, False, False]]
# segment_changes_indices = [[ 0, 8, 12, -1, -1, -1, -1, -1, -1]]
# segment_ids = [[ 1, 2, 2, -1, -1, -1, -1, -1, -1]]
# segment_changes_ag_first_true_masked = [[ True, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, True, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False]
# segment_changes_ag_indices = [[ 0, 21, -1, -1, -1, -1, -1, -1, -1]]
# seq_offsets = [[ 0, 21, 21, -1, -1, -1, -1, -1, -1]]
def
kv_seqoffsets_for_striped_for_rank
(
self
,
kv_segment_pos
,
kv_segment_ids
,
kv_segment_pos_ag
,
kv_segment_ids_ag
,
max_segments_per_seq
,
):
"""Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
AG kv seg ids and seg pos."""
# Calculate the segment pos change mask
segment_changes_first_true
=
jnp
.
concatenate
(
[
jnp
.
full
(
(
kv_segment_pos
.
shape
[
0
],
1
),
True
,
dtype
=
bool
),
# Assume valid element starts a segment and mask afterwards
(
kv_segment_pos
[...,
1
:]
!=
kv_segment_pos
[...,
:
-
1
]
+
1
),
# Segment pos changed
],
axis
=-
1
,
)
segment_changes_first_true_masked
=
jnp
.
where
(
kv_segment_ids
!=
0
,
segment_changes_first_true
,
False
)
# Get segment change indices for rank
segment_changes_indices
=
jax
.
vmap
(
lambda
sc_row
:
jnp
.
where
(
sc_row
,
size
=
max_segments_per_seq
,
fill_value
=-
1
)[
0
]
)(
segment_changes_first_true_masked
)
# Get segment ids associated with the segment_changes_indices for rank
segment_ids
=
jax
.
vmap
(
lambda
sci_row
,
ksi_row
:
jnp
.
where
(
sci_row
>=
0
,
ksi_row
[
sci_row
],
-
1
)
)(
segment_changes_indices
,
kv_segment_ids
)
# Get segment change indices for AG
segment_changes_ag_first_true
=
jnp
.
concatenate
(
[
jnp
.
full
(
(
kv_segment_pos
.
shape
[
0
],
1
),
True
,
dtype
=
bool
),
# Assume valid element starts a segment and mask afterwards
(
kv_segment_pos_ag
[...,
1
:]
!=
kv_segment_pos_ag
[...,
:
-
1
]
+
1
),
# Segment pos changed
],
axis
=-
1
,
)
segment_changes_ag_first_true_masked
=
jnp
.
where
(
kv_segment_ids_ag
!=
0
,
segment_changes_ag_first_true
,
False
)
# Get segment change indices for AG
segment_changes_ag_indices
=
jax
.
vmap
(
lambda
scag_row
:
jnp
.
where
(
scag_row
,
size
=
max_segments_per_seq
,
fill_value
=-
1
)[
0
]
)(
segment_changes_ag_first_true_masked
)
# Use the segment ids picked per rank to get the offsets from the AG indices
seq_offsets
=
jax
.
vmap
(
lambda
si_row
,
sca_row
:
jnp
.
where
(
si_row
>
0
,
sca_row
[
si_row
-
1
],
-
1
)
)(
segment_ids
,
segment_changes_ag_indices
)
return
seq_offsets
class
FusedAttnCPWithAllGatherFwdPrimitive
(
FusedAttnFwdPrimitive
):
class
FusedAttnCPWithAllGatherFwdPrimitive
(
FusedAttnFwdPrimitive
):
"""
"""
...
@@ -1376,7 +1762,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1376,7 +1762,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_sharding
,
softmax_aux_sharding
,
rng_state_sharding
)
out_shardings
=
(
out_sharding
,
softmax_aux_sharding
,
rng_state_sharding
)
...
@@ -1385,6 +1771,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1385,6 +1771,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -1404,7 +1791,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1404,7 +1791,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# meeting the expectation of the SPMD model.
# meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop.
# mask/sequence length tensor to avoid this unrolled loop.
def
_cross_attn
(
idx
,
q
,
k
,
v
,
bias
,
q_seqlen
,
kv_seqlen
,
seed
):
def
_cross_attn
(
idx
,
q
,
k
,
v
,
bias
,
softmax_offset
,
q_seqlen
,
kv_seqlen
,
seed
):
kv_max_seqlen
=
k
.
shape
[
1
]
kv_max_seqlen
=
k
.
shape
[
1
]
kv_seqlen_per_subrank
=
kv_max_seqlen
//
(
cp_size
*
2
)
kv_seqlen_per_subrank
=
kv_max_seqlen
//
(
cp_size
*
2
)
assert
kv_max_seqlen
%
cp_size
==
0
,
"sequence length must evenly divide cp size"
assert
kv_max_seqlen
%
cp_size
==
0
,
"sequence length must evenly divide cp size"
...
@@ -1425,12 +1812,12 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1425,12 +1812,12 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
q_seqlen_for_step
=
q_seqlen
/
(
cp_size
*
2
)
q_seqlen_for_step
=
q_seqlen
/
(
cp_size
*
2
)
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
kv_seqlen_for_step
=
(
kv_seqlen
/
(
cp_size
*
2
))
*
num_kv_chunks
kv_seqlen_for_step
=
(
kv_seqlen
/
(
cp_size
*
2
))
*
num_kv_chunks
output
,
softmax_aux
,
rng_state
=
FusedAttnFwdPrimitive
.
impl
(
output
,
softmax_aux
,
rng_state
=
FusedAttnFwdPrimitive
.
impl
(
q_split
[
sub_idx
],
q_split
[
sub_idx
],
k_unmasked
,
k_unmasked
,
v_unmasked
,
v_unmasked
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_seqlen_for_step
,
q_seqlen_for_step
,
kv_seqlen_for_step
,
kv_seqlen_for_step
,
...
@@ -1453,7 +1840,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1453,7 +1840,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_ag
,
v_ag
=
helper
.
all_gather_kv
(
k
,
v
)
k_ag
,
v_ag
=
helper
.
all_gather_kv
(
k
,
v
)
functions
=
[
functions
=
[
partial
(
_cross_attn
,
idx
,
q
,
k_ag
,
v_ag
,
bias
,
q_seqlen
,
kv_seqlen
,
seed
)
partial
(
_cross_attn
,
idx
,
q
,
k_ag
,
v_ag
,
bias
,
softmax_offset
,
q_seqlen
,
kv_seqlen
,
seed
)
for
idx
in
range
(
cp_size
)
for
idx
in
range
(
cp_size
)
]
]
...
@@ -1492,18 +1881,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1492,18 +1881,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
def
impl
(
def
impl
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1527,6 +1925,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1527,6 +1925,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1562,11 +1961,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1562,11 +1961,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
kv_seqlen_for_step
=
(
kv_seqlen
//
(
cp_size
*
2
))
*
num_kv_chunks
kv_seqlen_for_step
=
(
kv_seqlen
//
(
cp_size
*
2
))
*
num_kv_chunks
dq_local
,
dk_local
,
dv_local
,
dbias_local
=
FusedAttnBwdPrimitive
.
impl
(
dq_local
,
dk_local
,
dv_local
,
dbias_local
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q_split
[
sub_idx
],
q_split
[
sub_idx
],
k_unmasked
,
k_unmasked
,
v_unmasked
,
v_unmasked
,
bias
,
bias
,
softmax_offset
,
softmax_aux_split
[
sub_idx
],
softmax_aux_split
[
sub_idx
],
rng_state
,
rng_state
,
output_split
[
sub_idx
],
output_split
[
sub_idx
],
...
@@ -1604,6 +2004,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1604,6 +2004,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_ag
,
k_ag
,
v_ag
,
v_ag
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1621,7 +2022,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1621,7 +2022,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
dq
,
dk_local
,
dv_local
,
dbias
=
lax
.
switch
(
cp_rank
,
functions
)
dq
,
dk_local
,
dv_local
,
dbias
=
lax
.
switch
(
cp_rank
,
functions
)
dk
,
dv
=
helper
.
reduce_scatter_dkv
(
dk_local
,
dv_local
)
dk
,
dv
=
helper
.
reduce_scatter_dkv
(
dk_local
,
dv_local
)
return
dq
,
dk
,
dv
,
dbias
# Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
softmax_offset
)
return
dq
,
dk
,
dv
,
dbias
,
dummy_dsoftmax_offset
return
mesh
,
impl
,
out_shardings
,
arg_shardings
return
mesh
,
impl
,
out_shardings
,
arg_shardings
...
@@ -1629,6 +2032,314 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1629,6 +2032,314 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive
(
FusedAttnCPWithAllGatherBwdPrimitive
)
register_primitive
(
FusedAttnCPWithAllGatherBwdPrimitive
)
class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused attention forward primitive for context parallelism with striped load balancing.

    The partitioned implementation all-gathers K/V (together with the KV segment ids and
    positions) across the context-parallel axis and then runs the base forward primitive
    on the rank-local Q against the gathered K/V.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Build the partitioning rule for the striped all-gather forward pass.

        Returns the ``(mesh, impl, out_shardings, arg_shardings)`` tuple. When the
        context-parallel axis is not larger than one, the result of
        ``FusedAttnFwdPrimitive.partition`` is returned unchanged.
        """
        # Nothing context-parallel to do: defer to the base primitive and skip the extra setup.
        if get_mesh_axis_size(config.cp_axis, mesh) <= 1:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        # The seed input and the rng_state output share a single sharding object.
        rng_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))

        input_shardings = [info.sharding for info in arg_infos]
        # Slot 5 lines up with `seed` in the parameter list of `impl` below.
        input_shardings[5] = rng_sharding
        arg_shardings = tuple(input_shardings)

        # Output and softmax_aux keep the shardings carried by result_infos.
        out_shardings = (result_infos[0].sharding, result_infos[1].sharding, rng_sharding)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            # pylint: disable=unused-argument
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # Per the original author's note: cuDNN does not support right-aligned masking
            # with dynamic sequence length padding, so one computation is instantiated per
            # CP rank and selected with a runtime switch. Each case is expected to yield a
            # [..., SEQ/CP, ..] tensor as required by the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a
            # padding mask/sequence length tensor to avoid this unrolled loop.
            #
            # Inputs seen by each rank:
            #   * all-gathered k, v, kv segment ids and kv segment positions
            #   * rank-sharded q, _q_segment_ids, _q_segment_pos, _kv_segment_ids and
            #     _kv_segment_pos (the original note says these were reordered by the caller)
            def _cross_attn(
                q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed
            ):
                # The helper derives the q/kv seqlens and offsets, which are handed to
                # FusedAttnFwdPrimitive.impl directly. The segment id/pos slots receive empty
                # placeholders so that, per the original note, seqlens_from_segment_ids_pos()
                # uses the pre-computed seqlens/offsets rather than recomputing them.
                kv_max_seqlen = k.shape[1]

                # Per-rank estimate derived from the global max_segments_per_seq.
                segments_per_rank = helper.get_adjusted_max_segments_per_seq(
                    max_seqlen=kv_max_seqlen, cp_size=cp_size
                )

                q_lens = helper.q_seqlens_for_striped_for_rank(
                    _q_segment_ids, _q_segment_pos, segments_per_rank
                )
                q_offsets = helper.q_seqoffsets_for_striped_for_rank(
                    q_segment_ids=_q_segment_ids,
                    q_segment_pos=_q_segment_pos,
                    max_segments_per_seq=segments_per_rank,
                )
                kv_lens = helper.kv_seqlens_for_striped_for_rank(
                    kv_segment_ids=_kv_segment_ids,
                    kv_segment_pos=_kv_segment_pos,
                    max_segments_per_seq=segments_per_rank,
                )
                kv_offsets = helper.kv_seqoffsets_for_striped_for_rank(
                    kv_segment_pos=_kv_segment_pos,
                    kv_segment_ids=_kv_segment_ids,
                    kv_segment_pos_ag=kv_segment_pos_ag,
                    kv_segment_ids_ag=kv_segment_ids_ag,
                    max_segments_per_seq=segments_per_rank,
                )

                # One empty placeholder per segment id/pos argument of the base impl.
                segment_placeholders = tuple(jnp.zeros(0) for _ in range(4))
                step_config = helper.get_step_config_for_striped(
                    max_seqlen=kv_max_seqlen, cp_size=cp_size
                )

                output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                    q,  # sharded for rank
                    k,  # all-gathered
                    v,  # all-gathered
                    bias,
                    softmax_offset,
                    seed,
                    q_lens,
                    kv_lens,
                    q_offsets,
                    kv_offsets,
                    *segment_placeholders,
                    config=step_config,
                )
                return output, softmax_aux, rng_state

            # All-gather k, v and the kv segment ids / positions across the CP axis.
            k_ag, v_ag = helper.all_gather_kv(k, v)
            _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
                _kv_segment_ids, _kv_segment_pos
            )

            # NOTE(review): every branch is bound to the same arguments (the loop variable is
            # unused), so the switch picks among identical computations.
            branches = [
                partial(
                    _cross_attn,
                    q,
                    k_ag,
                    v_ag,
                    bias,
                    softmax_offset,
                    _kv_segment_ids_ag,
                    _kv_segment_pos_ag,
                    seed,
                )
                for _ in range(cp_size)
            ]
            return lax.switch(cp_rank, branches)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive)
class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused attention backward primitive for context parallelism with striped load balancing.

    The partitioned implementation all-gathers K/V (together with the KV segment ids and
    positions) across the context-parallel axis, runs the base backward primitive, and
    reduce-scatters the resulting dK/dV back to each context-parallel rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """
        Build the partitioning rule for the striped all-gather backward pass.

        Returns the ``(mesh, impl, out_shardings, arg_shardings)`` tuple. When the
        context-parallel axis is not larger than one, the result of
        ``FusedAttnBwdPrimitive.partition`` is returned unchanged.
        """
        # Nothing context-parallel to do: defer to the base primitive and skip the extra setup.
        if get_mesh_axis_size(config.cp_axis, mesh) <= 1:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Validate that this configuration is supported under context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        del result_infos

        # dq, dk, dv, dbias and dsoftmax_offset are sharded by the padded specs of
        # q, k, v, bias and softmax_offset (the first five inputs).
        grad_specs = [get_padded_spec(info) for info in arg_infos[:5]]
        out_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in grad_specs)
        arg_shardings = tuple(info.sharding for info in arg_infos)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_offset,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            _q_segment_ids,
            _kv_segment_ids,
            _q_segment_pos,
            _kv_segment_pos,
        ):
            # pylint: disable=unused-argument
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # Mirrors the forward striped all-gather primitive: one branch per CP rank,
            # selected at runtime with lax.switch.
            def _cross_attn_bwd(
                q,
                k,
                v,
                bias,
                softmax_offset,
                softmax_aux,
                rng_state,
                output,
                doutput,
                _q_segment_ids,
                kv_segment_ids_ag,
                _q_segment_pos,
                kv_segment_pos_ag,
            ):
                # The helper derives the q/kv seqlens and offsets, which are handed to
                # FusedAttnBwdPrimitive.impl directly. The segment id/pos slots receive empty
                # placeholders so that, per the original note, seqlens_from_segment_ids_pos()
                # uses the pre-computed seqlens/offsets rather than recomputing them.
                kv_max_seqlen = k.shape[1]

                # Per-rank estimate derived from the global max_segments_per_seq.
                segments_per_rank = helper.get_adjusted_max_segments_per_seq(
                    max_seqlen=kv_max_seqlen, cp_size=cp_size
                )

                q_lens = helper.q_seqlens_for_striped_for_rank(
                    _q_segment_ids, _q_segment_pos, segments_per_rank
                )
                q_offsets = helper.q_seqoffsets_for_striped_for_rank(
                    q_segment_ids=_q_segment_ids,
                    q_segment_pos=_q_segment_pos,
                    max_segments_per_seq=segments_per_rank,
                )
                # The rank-local kv segment ids/pos come from the enclosing impl's arguments.
                kv_lens = helper.kv_seqlens_for_striped_for_rank(
                    kv_segment_ids=_kv_segment_ids,
                    kv_segment_pos=_kv_segment_pos,
                    max_segments_per_seq=segments_per_rank,
                )
                kv_offsets = helper.kv_seqoffsets_for_striped_for_rank(
                    kv_segment_pos=_kv_segment_pos,
                    kv_segment_ids=_kv_segment_ids,
                    kv_segment_pos_ag=kv_segment_pos_ag,
                    kv_segment_ids_ag=kv_segment_ids_ag,
                    max_segments_per_seq=segments_per_rank,
                )

                # One empty placeholder per segment id/pos argument of the base impl.
                segment_placeholders = tuple(jnp.zeros(0) for _ in range(4))
                step_config = helper.get_step_config_for_striped(
                    max_seqlen=kv_max_seqlen, cp_size=cp_size
                )

                # The fifth result of the base impl is discarded here; a placeholder
                # dsoftmax_offset is produced by the enclosing impl instead.
                dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
                    q,  # sharded for rank
                    k,  # all-gathered
                    v,  # all-gathered
                    bias,
                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_lens,
                    kv_lens,
                    q_offsets,
                    kv_offsets,
                    *segment_placeholders,
                    config=step_config,
                )
                return dq_local, dk_local, dv_local, dbias_local

            # All-gather k, v and the kv segment ids / positions across the CP axis.
            k_ag, v_ag = helper.all_gather_kv(k, v)
            _kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
                _kv_segment_ids, _kv_segment_pos
            )

            # NOTE(review): every branch is bound to the same arguments (the loop variable is
            # unused), so the switch picks among identical computations.
            branches = [
                partial(
                    _cross_attn_bwd,
                    q,
                    k_ag,
                    v_ag,
                    bias,
                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    _q_segment_ids,
                    _kv_segment_ids_ag,
                    _q_segment_pos,
                    _kv_segment_pos_ag,
                )
                for _ in range(cp_size)
            ]
            dq, dk_local, dv_local, dbias = lax.switch(cp_rank, branches)

            # Reduce-scatter the dk/dv computed against the gathered k/v back to each rank.
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            # Placeholder dsoftmax_offset so the number of outputs matches out_shardings
            # (the all-gather CP path does not use it).
            dummy_dsoftmax_offset = jnp.empty_like(softmax_offset)

            return dq, dk, dv, dbias, dummy_dsoftmax_offset

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
_FusedAttnCPWithP2PHelper
:
class
_FusedAttnCPWithP2PHelper
:
"""Helper class to assist with running the P2P ring strategy for CP attention."""
"""Helper class to assist with running the P2P ring strategy for CP attention."""
...
@@ -1639,7 +2350,8 @@ class _FusedAttnCPWithP2PHelper:
...
@@ -1639,7 +2350,8 @@ class _FusedAttnCPWithP2PHelper:
@
staticmethod
@
staticmethod
def
use_scanloop
():
def
use_scanloop
():
"""Returns true if the implementation will use a scan loop for iteration."""
"""Returns true if the implementation will use a scan loop for iteration."""
use_scan
=
bool
(
int
(
os
.
getenv
(
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
,
"1"
)))
# TODO(KshitijLakhani): Reset default to 1, once the extra kv permute op issue is resolved
use_scan
=
bool
(
int
(
os
.
getenv
(
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
,
"0"
)))
return
use_scan
return
use_scan
def
check_supported
(
self
):
def
check_supported
(
self
):
...
@@ -1679,13 +2391,20 @@ class _FusedAttnCPWithP2PHelper:
...
@@ -1679,13 +2391,20 @@ class _FusedAttnCPWithP2PHelper:
if
self
.
config
.
dropout_probability
!=
0.0
:
if
self
.
config
.
dropout_probability
!=
0.0
:
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
# We want to encourage use of scan loop to minimize unrolling and ensure more
if
self
.
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
# predictable scheduling from XLA. The unrolled flavor will be supported but
raise
ValueError
(
# not the preferred implementation.
f
"
{
header
}
only supports VANILLA_SOFTMAX, got:
{
self
.
config
.
softmax_type
}
"
if
not
self
.
use_scanloop
():
)
# TODO(KshitijLakhani): Flip the condition to check for disabled scan loop and warn
# against using unrolled loops once the scan issue is resolved.
# We want to discourage the use of scan loop as additional kv permute op observed.
# The scan loop flavor will be supported but not the preferred implementation until
# a resolution for the additional kv permute op, which degrades perf, is found.
if
self
.
use_scanloop
():
warnings
.
warn
(
warnings
.
warn
(
"Scan loop is
dis
abled for fused ring attention. To
en
able set"
"Scan loop is
en
abled for fused ring attention. To
dis
able set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=
1
in your environment"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=
0
in your environment"
)
)
# If using scanloop, idx in scan_kv_block() will be a traced device value, but
# If using scanloop, idx in scan_kv_block() will be a traced device value, but
...
@@ -1703,6 +2422,7 @@ class _FusedAttnCPWithP2PHelper:
...
@@ -1703,6 +2422,7 @@ class _FusedAttnCPWithP2PHelper:
return
_FusedAttnConfig
(
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
QKVLayout
.
BSHD_BS2HD
,
qkv_layout
=
QKVLayout
.
BSHD_BS2HD
,
scaling_factor
=
self
.
config
.
scaling_factor
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
dropout_probability
=
self
.
config
.
dropout_probability
,
...
@@ -1712,6 +2432,7 @@ class _FusedAttnCPWithP2PHelper:
...
@@ -1712,6 +2432,7 @@ class _FusedAttnCPWithP2PHelper:
context_parallel_load_balanced
=
self
.
config
.
context_parallel_load_balanced
,
context_parallel_load_balanced
=
self
.
config
.
context_parallel_load_balanced
,
cp_axis
=
self
.
config
.
cp_axis
,
cp_axis
=
self
.
config
.
cp_axis
,
cp_striped_window_size
=
None
,
cp_striped_window_size
=
None
,
stripe_size
=
self
.
config
.
stripe_size
,
)
)
def
stack_kv
(
self
,
k
,
v
):
def
stack_kv
(
self
,
k
,
v
):
...
@@ -1783,7 +2504,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1783,7 +2504,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
# Ensure segment_pos gets same sharding as ID.
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
...
@@ -1795,6 +2516,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1795,6 +2516,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -1840,6 +2562,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1840,6 +2562,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen_per_step
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
kv_seqlen_per_step
,
...
@@ -1865,6 +2588,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1865,6 +2588,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part
,
kv_part
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen_per_step
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
kv_seqlen_per_step
,
...
@@ -1887,6 +2611,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1887,6 +2611,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen_per_step
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
kv_seqlen_per_step
,
...
@@ -1990,18 +2715,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1990,18 +2715,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
# Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
.
check_supported
()
helper
.
check_supported
()
...
@@ -2011,6 +2742,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2011,6 +2742,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2054,11 +2786,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2054,11 +2786,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def
mask_compute
(
attn_mask_type
):
def
mask_compute
(
attn_mask_type
):
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2082,11 +2815,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2082,11 +2815,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
//
2
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
//
2
kv_part
=
lax
.
slice_in_dim
(
kv
,
0
,
kv_max_seqlen
//
2
,
axis
=
1
)
kv_part
=
lax
.
slice_in_dim
(
kv
,
0
,
kv_max_seqlen
//
2
,
axis
=
1
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
kv_part
,
kv_part
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2120,11 +2854,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2120,11 +2854,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
softmax_aux
,
q_max_seqlen
//
2
,
q_max_seqlen
,
axis
=
2
softmax_aux
,
q_max_seqlen
//
2
,
q_max_seqlen
,
axis
=
2
)
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q_part
,
q_part
,
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux_part
,
softmax_aux_part
,
rng_state
,
rng_state
,
output_part
,
output_part
,
...
@@ -2184,7 +2919,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2184,7 +2919,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
dk
,
dv
=
helper
.
unstack_kv
(
dk_dv
)
dk
,
dv
=
helper
.
unstack_kv
(
dk_dv
)
return
dq
,
dk
,
dv
,
global_dbias
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
_softmax_offset
)
return
dq
,
dk
,
dv
,
global_dbias
,
dummy_dsoftmax_offset
return
mesh
,
ring_attn_bwd_impl
,
out_shardings
,
arg_shardings
return
mesh
,
ring_attn_bwd_impl
,
out_shardings
,
arg_shardings
...
@@ -2273,7 +3010,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2273,7 +3010,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
# Ensure segment_pos gets same sharding as ID.
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
...
@@ -2285,6 +3022,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2285,6 +3022,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -2336,6 +3074,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2336,6 +3074,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -2345,7 +3084,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2345,7 +3084,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids
,
kv_segment_ids
,
q_segment_pos
,
q_segment_pos
,
kv_segment_pos
,
kv_segment_pos
,
config
,
config
=
config
,
)
)
if
config
.
window_size
!=
(
-
1
,
-
1
):
if
config
.
window_size
!=
(
-
1
,
-
1
):
...
@@ -2420,8 +3159,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2420,8 +3159,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
# dq, dk, dv, dbias
, dsoftmax_offset
sharding = q, k, v, bias
, softmax_offset
sharding
out_shardings
=
tuple
(
arg
.
sharding
for
arg
in
arg_infos
[:
4
])
out_shardings
=
tuple
(
arg
.
sharding
for
arg
in
arg_infos
[:
5
])
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
.
check_supported
()
helper
.
check_supported
()
...
@@ -2431,6 +3170,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2431,6 +3170,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2478,11 +3218,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2478,11 +3218,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_pos_next
=
helper
.
permute_kv
(
kv_segment_pos
,
cp_perm
)
kv_segment_pos_next
=
helper
.
permute_kv
(
kv_segment_pos
,
cp_perm
)
def
compute
(
config
):
def
compute
(
config
):
dq_per_step
,
dkv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dkv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2536,7 +3277,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2536,7 +3277,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
dk
,
dv
=
helper
.
unstack_kv
(
dkv
)
dk
,
dv
=
helper
.
unstack_kv
(
dkv
)
return
dq
,
dk
,
dv
,
global_dbias
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
_softmax_offset
)
return
dq
,
dk
,
dv
,
global_dbias
,
dummy_dsoftmax_offset
return
mesh
,
bwd_impl
,
out_shardings
,
arg_shardings
return
mesh
,
bwd_impl
,
out_shardings
,
arg_shardings
...
@@ -2545,7 +3288,7 @@ register_primitive(FusedRingAttnStripedBwdPrimitive)
...
@@ -2545,7 +3288,7 @@ register_primitive(FusedRingAttnStripedBwdPrimitive)
def
_maybe_context_parallel_axis
(
cp_axis
:
str
):
def
_maybe_context_parallel_axis
(
cp_axis
:
str
):
if
not
cp_axis
:
if
not
cp_axis
and
is_mesh_available
()
:
gmr
=
global_mesh_resource
()
gmr
=
global_mesh_resource
()
if
gmr
is
not
None
:
if
gmr
is
not
None
:
cp_axis
=
gmr
.
cp_resource
cp_axis
=
gmr
.
cp_resource
...
@@ -2557,10 +3300,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
...
@@ -2557,10 +3300,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
def
fused_attn_fwd
(
def
fused_attn_fwd
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
softmax_type
:
AttnSoftmaxType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
...
@@ -2570,6 +3315,7 @@ def fused_attn_fwd(
...
@@ -2570,6 +3315,7 @@ def fused_attn_fwd(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
stripe_size
:
int
|
None
=
None
,
)
->
jnp
.
ndarray
:
)
->
jnp
.
ndarray
:
"""
"""
Perform the forward pass with cuDNN fused attention implementations.
Perform the forward pass with cuDNN fused attention implementations.
...
@@ -2585,6 +3331,7 @@ def fused_attn_fwd(
...
@@ -2585,6 +3331,7 @@ def fused_attn_fwd(
query has a different shape (e.g., cross-attention).
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (jnp.ndarray):
q_seq_offsets (jnp.ndarray):
...
@@ -2594,6 +3341,7 @@ def fused_attn_fwd(
...
@@ -2594,6 +3341,7 @@ def fused_attn_fwd(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -2606,6 +3354,7 @@ def fused_attn_fwd(
...
@@ -2606,6 +3354,7 @@ def fused_attn_fwd(
context_parallel_causal_load_balanced (bool):
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_axis (str): The name of the context parallel axis.
stripe_size (int | None): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing
Returns:
Returns:
(jnp.ndarray): The output tensor from the fused attention.
(jnp.ndarray): The output tensor from the fused attention.
"""
"""
...
@@ -2633,10 +3382,36 @@ def fused_attn_fwd(
...
@@ -2633,10 +3382,36 @@ def fused_attn_fwd(
assert
bias
is
None
assert
bias
is
None
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
if
softmax_offset
is
None
:
assert
(
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
),
f
"Softmax type
{
softmax_type
}
is not supported when softmax_offset is None"
if
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
num_heads
=
qkv
[
0
].
shape
[
-
2
]
# Create tensor [1, h, 1, 1] filled with zeros (logit value = 0)
# This adds exp(0 - x_max) = exp(-x_max) to the denominator,
# which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
softmax_offset
=
jnp
.
zeros
((
1
,
num_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
# Shard by heads dimension
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
else
:
assert
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
softmax_offset
=
jnp
.
zeros
(
0
,
dtype
=
jnp
.
float32
)
else
:
assert
softmax_offset
.
dtype
==
jnp
.
float32
# Shard by heads dimension if not VANILLA_SOFTMAX
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
fused_config
=
_FusedAttnConfig
(
fused_config
=
_FusedAttnConfig
(
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
@@ -2645,11 +3420,15 @@ def fused_attn_fwd(
...
@@ -2645,11 +3420,15 @@ def fused_attn_fwd(
context_parallel_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_load_balanced
=
context_parallel_causal_load_balanced
,
cp_axis
=
_maybe_context_parallel_axis
(
context_parallel_axis
),
cp_axis
=
_maybe_context_parallel_axis
(
context_parallel_axis
),
cp_striped_window_size
=
None
,
cp_striped_window_size
=
None
,
stripe_size
=
stripe_size
,
)
)
primitive
=
None
primitive
=
None
match
context_parallel_strategy
:
match
context_parallel_strategy
:
case
CPStrategy
.
DEFAULT
|
CPStrategy
.
ALL_GATHER
:
case
CPStrategy
.
DEFAULT
|
CPStrategy
.
ALL_GATHER
:
if
qkv_layout
.
is_thd
():
primitive
=
FusedAttnCPStripedWithAllGatherFwdPrimitive
.
outer_primitive
else
:
primitive
=
FusedAttnCPWithAllGatherFwdPrimitive
.
outer_primitive
primitive
=
FusedAttnCPWithAllGatherFwdPrimitive
.
outer_primitive
case
CPStrategy
.
RING
:
case
CPStrategy
.
RING
:
# We must use stripe attention for THD-RING
# We must use stripe attention for THD-RING
...
@@ -2662,6 +3441,7 @@ def fused_attn_fwd(
...
@@ -2662,6 +3441,7 @@ def fused_attn_fwd(
output
,
softmax_aux
,
rng_state
=
primitive
.
bind
(
output
,
softmax_aux
,
rng_state
=
primitive
.
bind
(
*
qkv_for_primitive
,
*
qkv_for_primitive
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
*
seq_desc_flatten
,
*
seq_desc_flatten
,
config
=
fused_config
,
config
=
fused_config
,
...
@@ -2673,6 +3453,7 @@ def fused_attn_fwd(
...
@@ -2673,6 +3453,7 @@ def fused_attn_fwd(
def
fused_attn_bwd
(
def
fused_attn_bwd
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
softmax_aux
:
jnp
.
ndarray
,
softmax_aux
:
jnp
.
ndarray
,
rng_state
:
jnp
.
ndarray
,
rng_state
:
jnp
.
ndarray
,
output
:
jnp
.
ndarray
,
output
:
jnp
.
ndarray
,
...
@@ -2681,6 +3462,7 @@ def fused_attn_bwd(
...
@@ -2681,6 +3462,7 @@ def fused_attn_bwd(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -2689,6 +3471,7 @@ def fused_attn_bwd(
...
@@ -2689,6 +3471,7 @@ def fused_attn_bwd(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
stripe_size
:
int
|
None
=
None
,
):
):
"""
"""
Perform the backward pass of the cuDNN fused attention implementations.
Perform the backward pass of the cuDNN fused attention implementations.
...
@@ -2702,6 +3485,7 @@ def fused_attn_bwd(
...
@@ -2702,6 +3485,7 @@ def fused_attn_bwd(
query has a different shape (e.g., cross-attention).
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
...
@@ -2714,6 +3498,7 @@ def fused_attn_bwd(
...
@@ -2714,6 +3498,7 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,].
The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -2726,6 +3511,7 @@ def fused_attn_bwd(
...
@@ -2726,6 +3511,7 @@ def fused_attn_bwd(
context_parallel_causal_load_balanced (bool):
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_axis (str): The name of the context parallel axis.
stripe_size (int | None): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing
Returns:
Returns:
Tuple[jnp.ndarray, ...], jnp.ndarray:
Tuple[jnp.ndarray, ...], jnp.ndarray:
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
...
@@ -2755,6 +3541,28 @@ def fused_attn_bwd(
...
@@ -2755,6 +3541,28 @@ def fused_attn_bwd(
assert
bias
is
None
assert
bias
is
None
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
if
softmax_offset
is
None
:
assert
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
,
f
"Unknown
{
softmax_type
=
}
"
if
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
num_heads
=
qkv
[
0
].
shape
[
-
2
]
# Create tensor [1, h, 1, 1] filled with zeros
softmax_offset
=
jnp
.
zeros
((
1
,
num_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
# Shard by heads dimension
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
elif
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
jnp
.
zeros
(
0
,
dtype
=
jnp
.
float32
)
else
:
raise
NotImplementedError
(
f
"Unknown
{
softmax_type
=
}
"
)
else
:
softmax_offset
=
softmax_offset
.
astype
(
jnp
.
float32
)
# Shard by heads dimension if not VANILLA_SOFTMAX
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
# sm100+
compute_capabilities
=
get_all_device_compute_capability
()
compute_capabilities
=
get_all_device_compute_capability
()
...
@@ -2767,6 +3575,7 @@ def fused_attn_bwd(
...
@@ -2767,6 +3575,7 @@ def fused_attn_bwd(
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
@@ -2775,11 +3584,15 @@ def fused_attn_bwd(
...
@@ -2775,11 +3584,15 @@ def fused_attn_bwd(
context_parallel_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_load_balanced
=
context_parallel_causal_load_balanced
,
cp_axis
=
_maybe_context_parallel_axis
(
context_parallel_axis
),
cp_axis
=
_maybe_context_parallel_axis
(
context_parallel_axis
),
cp_striped_window_size
=
None
,
cp_striped_window_size
=
None
,
stripe_size
=
stripe_size
,
)
)
primitive
=
None
primitive
=
None
match
context_parallel_strategy
:
match
context_parallel_strategy
:
case
CPStrategy
.
DEFAULT
|
CPStrategy
.
ALL_GATHER
:
case
CPStrategy
.
DEFAULT
|
CPStrategy
.
ALL_GATHER
:
if
qkv_layout
.
is_thd
():
primitive
=
FusedAttnCPStripedWithAllGatherBwdPrimitive
.
outer_primitive
else
:
primitive
=
FusedAttnCPWithAllGatherBwdPrimitive
.
outer_primitive
primitive
=
FusedAttnCPWithAllGatherBwdPrimitive
.
outer_primitive
case
CPStrategy
.
RING
:
case
CPStrategy
.
RING
:
if
qkv_layout
.
is_thd
():
if
qkv_layout
.
is_thd
():
...
@@ -2788,9 +3601,10 @@ def fused_attn_bwd(
...
@@ -2788,9 +3601,10 @@ def fused_attn_bwd(
primitive
=
FusedRingAttnBwdPrimitive
.
outer_primitive
primitive
=
FusedRingAttnBwdPrimitive
.
outer_primitive
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
*
qkv_grads
,
bias_grad
=
primitive
.
bind
(
*
qkv_grads
,
bias_grad
,
softmax_offset_grad
=
primitive
.
bind
(
*
qkv_for_primitive
,
*
qkv_for_primitive
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2798,4 +3612,4 @@ def fused_attn_bwd(
...
@@ -2798,4 +3612,4 @@ def fused_attn_bwd(
*
seq_desc_flatten
,
*
seq_desc_flatten
,
config
=
fused_config
,
config
=
fused_config
,
)
)
return
tuple
(
qkv_grads
[:
len
(
qkv
)]),
bias_grad
return
tuple
(
qkv_grads
[:
len
(
qkv
)]),
bias_grad
,
softmax_offset_grad
transformer_engine/jax/cpp_extensions/base.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX/TE base custom ops"""
"""JAX/TE base custom ops"""
...
@@ -176,6 +176,9 @@ _primitive_registry = {}
...
@@ -176,6 +176,9 @@ _primitive_registry = {}
def
register_primitive
(
cls
,
outer_only
=
False
):
def
register_primitive
(
cls
,
outer_only
=
False
):
"""
"""
Register a JAX primitive and add it to the internal registry.
Register a JAX primitive and add it to the internal registry.
Inner primitive - single device, no sharding awareness, eager mode fallback
Outer primitive - multi device, sharding aware, partition() distributes work,
used when there's a dev mesh context
"""
"""
_primitive_registry
[
cls
.
__name__
]
=
cls
_primitive_registry
[
cls
.
__name__
]
=
cls
...
@@ -190,14 +193,17 @@ def register_primitive(cls, outer_only=False):
...
@@ -190,14 +193,17 @@ def register_primitive(cls, outer_only=False):
inner_p
=
core
.
Primitive
(
cls
.
name
)
inner_p
=
core
.
Primitive
(
cls
.
name
)
dispatch
.
prim_requires_devices_during_lowering
.
add
(
inner_p
)
dispatch
.
prim_requires_devices_during_lowering
.
add
(
inner_p
)
inner_p
.
multiple_results
=
cls
.
multiple_results
inner_p
.
multiple_results
=
cls
.
multiple_results
# Define eager execution implementation (by invoking it's MLIR lowering)
inner_p
.
def_impl
(
partial
(
xla
.
apply_primitive
,
inner_p
))
inner_p
.
def_impl
(
partial
(
xla
.
apply_primitive
,
inner_p
))
inner_p
.
def_abstract_eval
(
cls
.
abstract
)
inner_p
.
def_abstract_eval
(
cls
.
abstract
)
mlir
.
register_lowering
(
inner_p
,
cls
.
lowering
,
platform
=
"cuda"
)
mlir
.
register_lowering
(
inner_p
,
cls
.
lowering
,
platform
=
"cuda"
)
cls
.
inner_primitive
=
inner_p
cls
.
inner_primitive
=
inner_p
# Create the outer primitive for distributed execution
outer_p
=
core
.
Primitive
(
name_of_wrapper_p
())
outer_p
=
core
.
Primitive
(
name_of_wrapper_p
())
dispatch
.
prim_requires_devices_during_lowering
.
add
(
outer_p
)
dispatch
.
prim_requires_devices_during_lowering
.
add
(
outer_p
)
outer_p
.
multiple_results
=
cls
.
multiple_results
outer_p
.
multiple_results
=
cls
.
multiple_results
# Define the eager execution implementation
outer_p
.
def_impl
(
cls
.
outer_impl
)
outer_p
.
def_impl
(
cls
.
outer_impl
)
outer_p
.
def_abstract_eval
(
cls
.
outer_abstract
)
outer_p
.
def_abstract_eval
(
cls
.
outer_abstract
)
batching
.
primitive_batchers
[
outer_p
]
=
cls
.
batcher
batching
.
primitive_batchers
[
outer_p
]
=
cls
.
batcher
...
...
transformer_engine/jax/cpp_extensions/gemm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX te modules"""
"""JAX te modules"""
...
@@ -39,12 +39,12 @@ from ..quantize import (
...
@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer
,
Quantizer
,
GroupedQuantizer
,
GroupedQuantizer
,
QuantizerSet
,
QuantizerSet
,
QuantizeLayout
,
noop_quantizer_set
,
noop_quantizer_set
,
is_fp8_gemm_with_all_layouts_supported
,
is_fp8_gemm_with_all_layouts_supported
,
apply_padding_to_scale_inv
,
apply_padding_to_scale_inv
,
get_quantize_config_with_recipe
,
get_quantize_config_with_recipe
,
get_global_quantize_recipe
,
get_global_quantize_recipe
,
QuantizeLayout
,
)
)
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
..sharding
import
(
from
..sharding
import
(
...
...
Prev
1
…
19
20
21
22
23
24
25
26
27
…
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment