Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -9,7 +9,7 @@ These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations

-from typing import Optional, Tuple, Iterable, Union
+from typing import Optional, Tuple, Iterable, Union, List

import torch
import transformer_engine_torch as tex
@@ -62,12 +62,17 @@ class DebugQuantizer(Quantizer):
        self.tp_group = tp_group  # used in inspect_tensor calls
        self.iteration = TEDebugState.get_iteration()

-        # .internal = True is slightly faster, but results
-        # in errors when caching the weights.
-        # Setting .internal = False is safer.
-        if parent_quantizer is not None:
-            parent_quantizer.internal = False
+        # Configure parent quantizer
+        if parent_quantizer is not None:
+            # .internal = True is slightly faster, but results
+            # in errors when caching the weights.
+            # Setting .internal = False is safer.
+            parent_quantizer.internal = False
+            # .optimize_for_gemm = True is not supported because debug
+            # quantizers perform non-GEMM operations.
+            parent_quantizer.optimize_for_gemm = False

        self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

        # next iteration when this quantizer will call any API
@@ -556,6 +561,23 @@ class DebugQuantizer(Quantizer):
        if not self.output_tensor:
            self._update_parent_quantizer_usage()

+    @classmethod
+    def multi_tensor_quantize(
+        cls,
+        tensor: torch.Tensor,
+        quantizers: List[Quantizer],
+        m_splits: List[int],
+        activation_dtype: torch.dtype,
+    ) -> List[DebugQuantizedTensor]:
+        """
+        Splits a tensor into a list of tensors and quantizes each tensor using a list of quantizers.
+        """
+        tensors = torch.split(tensor, m_splits)
+        output = []
+        for tensor, quantizer in zip(tensors, quantizers):
+            output.append(quantizer.quantize(tensor, dtype=activation_dtype))
+        return output
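For context, a minimal sketch (not part of this commit) of the splitting semantics the new `multi_tensor_quantize` relies on: `torch.split` with a list of sizes cuts the tensor along dim 0 into per-group views, one per quantizer. All names below are illustrative only.

import torch

x = torch.randn(12, 4)
chunks = torch.split(x, [4, 2, 6])  # sizes along dim 0, like m_splits
assert [c.shape[0] for c in chunks] == [4, 2, 6]
# Each chunk would then be quantized by its corresponding quantizer.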
class DebugQuantizedTensor(QuantizedTensorStorage):
    """
@@ -623,9 +645,9 @@ class DebugQuantizedTensor(QuantizedTensorStorage):
        """Is used in the python gemm() to get tensor or transpose of the tensor."""
        return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor

-    def size(self):
+    def size(self, *args):
        """Size of the tensor."""
-        return self.rowwise_gemm_tensor.size()
+        return self.rowwise_gemm_tensor.size(*args)

    def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
        """Update usage of the tensor."""
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_QKV_Format
from transformer_engine_jax import nvte_get_qkv_format
+from transformer_engine_jax import NVTE_Softmax_Type

from . import cpp_extensions as tex
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
    ]

+class AttnSoftmaxType(Enum):
+    """
+    VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+    OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
+    LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+    where alpha is a learnable parameter in shape [H].
+    """
+
+    VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
+    OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
+    LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX
+
+    @classmethod
+    def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
+        """Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
+        softmax_type_map = {
+            "vanilla": cls.VANILLA_SOFTMAX,
+            "off_by_one": cls.OFF_BY_ONE_SOFTMAX,
+            "learnable": cls.LEARNABLE_SOFTMAX,
+        }
+        result = softmax_type_map.get(softmax_type)
+        if result is None:
+            raise ValueError(
+                f"Unknown softmax_type: {softmax_type}. "
+                "Valid options: 'vanilla', 'off_by_one', 'learnable'"
+            )
+        return result
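As a reference for the three variants above, a minimal numerically-stable sketch (an illustration, not the fused cuDNN kernel): the off-by-one and learnable forms only change the denominator, with the "+1" and exp(alpha) terms rescaled by the running max.

import jax.numpy as jnp

def softmax_with_offset(logits, alpha=None, off_by_one=False):
    # logits: [..., kv_len]; alpha: per-head offset broadcastable against logits[..., :1]
    m = jnp.max(logits, axis=-1, keepdims=True)
    num = jnp.exp(logits - m)
    denom = jnp.sum(num, axis=-1, keepdims=True)
    if off_by_one:
        denom = denom + jnp.exp(-m)  # the "+1" term, rescaled by exp(-m)
    if alpha is not None:
        denom = denom + jnp.exp(alpha - m)  # the exp(alpha[j]) term, rescaled
    return num / denom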
class QKVFormat(Enum):
    """
    SBHD: q,k,v memory layout with [s, b, ..., h, d]
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
    qkv_layout,
    attn_bias_type,
    attn_mask_type,
+    softmax_type,
    dropout_probability,
    q_num_heads,
    kv_num_heads,
@@ -313,6 +344,7 @@
    """
    To check whether the fused attention kernel is supported
    """
+    window_size_tuple = (-1, -1) if window_size is None else window_size

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
@@ -322,6 +354,7 @@
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
+            softmax_type,
            dropout_probability,
            q_num_heads,
            kv_num_heads,
@@ -329,7 +362,7 @@
            kv_max_seqlen,
            head_dim_qk,
            head_dim_v,
-            (-1, -1) if window_size is None else window_size,
+            window_size_tuple,
        )

    return make_helper(attn_mask_type).is_fused_attn_kernel_available()
@@ -353,23 +386,57 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
    return batch, q_max_seqlen, kv_max_seqlen
-def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
+def reorder_causal_load_balancing(
+    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
+):
    """Reorders a tensor for load balancing the compute of causal attention."""
    if strategy == ReorderStrategy.DualChunkSwap:
+        if stripe_size is not None:
+            raise ValueError(
+                f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
+                " None"
+            )
        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
    if strategy == ReorderStrategy.Striped:
-        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
+        # stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
+        # stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
+        if stripe_size is not None and stripe_size <= 0:
+            raise ValueError(
+                f"Incorrect value for CP striped reordering {stripe_size=}. stripe_size must be a"
+                " positive integer"
+            )
+        # Supporting old API defaults of stripe_size=1
+        effective_stripe_size = 1 if stripe_size is None else stripe_size
+        return tex.attention.reorder_causal_striped(
+            tensor, cp_size, seq_dim, False, effective_stripe_size
+        )
    raise ValueError(f"Unsupported {strategy=}")

def inverse_reorder_causal_load_balancing(
-    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
+    tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
    """Inverse operation of `reorder_causal_load_balancing`."""
    if strategy == ReorderStrategy.DualChunkSwap:
+        if stripe_size is not None:
+            raise ValueError(
+                f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
+                " None"
+            )
        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
    if strategy == ReorderStrategy.Striped:
-        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
+        # stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
+        # stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
+        if stripe_size is not None and stripe_size <= 0:
+            raise ValueError(
+                f"Incorrect value for CP reordering {stripe_size=}. stripe_size must be a positive"
+                " integer"
+            )
+        # Supporting old API defaults of stripe_size=1
+        effective_stripe_size = 1 if stripe_size is None else stripe_size
+        return tex.attention.reorder_causal_striped(
+            tensor, cp_size, seq_dim, True, effective_stripe_size
+        )
    raise ValueError(f"Unsupported {strategy=}")
@@ -497,6 +564,11 @@ def _segment_ids_pos_to_seqlens_offsets(
    #
    # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
    # examine only O(Q+KV) elements.
+    # For seqlens and seqoffsets calculations, the intermediate (temp) attn_mask creation
+    # using the segment ids and pos along with mask type (causal or brcm) is sufficient.
+    # It does not need to involve SW for this mask's creation
    # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
    if (attn_mask_type.is_causal() and window_size is None) or (
        window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
@@ -558,21 +630,6 @@
    )
    attn_mask = jnp.logical_and(segment_mask, causal_mask)

-    # TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
-    swa_mask = (
-        make_swa_mask(
-            segment_pos_q,
-            segment_pos_kv,
-            window_size,
-            dtype=jnp.bool,
-            segment_ids_q=segment_ids_q,
-            segment_ids_kv=segment_ids_kv,
-        )
-        if attn_mask_type.is_bottom_right()
-        else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
-    )
-    attn_mask = jnp.logical_and(attn_mask, swa_mask)
    attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
    q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
        attn_mask_with_id, max_segments_per_seq
@@ -601,7 +658,7 @@ class SequenceDescriptor:
        - SequenceDescriptor.from_seqlens_and_offsets
            For THD (packed) cases, where each batch may have not only 1 sequence.
        - SequenceDescriptor.from_segment_ids_and_pos
-            Experimental feature for THD (packed) cases with context parallelism.
+            Experimental feature for BSHD (with and without reordering) and THD (packed) cases without reordering
    """

    seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
...@@ -739,9 +796,14 @@ class SequenceDescriptor: ...@@ -739,9 +796,14 @@ class SequenceDescriptor:
cls, cls,
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
*,
is_thd: bool,
is_segment_ids_reordered: bool,
) -> SequenceDescriptor: ) -> SequenceDescriptor:
""" """
Experimental factory method for inputs with segment IDs and optional positions. (THD) Experimental factory method for inputs with segment IDs and optional positions.
segment_pos = None to be used only for: BSHD with or without load balancing and,
THD without load balancing
Args: Args:
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
- q_segment_ids (jnp.ndarray): - q_segment_ids (jnp.ndarray):
...@@ -755,22 +817,84 @@ class SequenceDescriptor: ...@@ -755,22 +817,84 @@ class SequenceDescriptor:
The position inside each segment for query, with shape [batch, max_seqlen]. The position inside each segment for query, with shape [batch, max_seqlen].
- kv_segment_pos (jnp.ndarray): - kv_segment_pos (jnp.ndarray):
The position inside each segment for key, value, with shape [batch, max_seqlen]. The position inside each segment for key, value, with shape [batch, max_seqlen].
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing.
Only THD with load balancing is expected to have this flag set to True
Return: Return:
A SequenceDescriptor with segment_ids/segment_pos initialized. A SequenceDescriptor with segment_ids/segment_pos initialized.
""" """
q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)
if segment_pos is not None: # Using defaults : segment pos has to be generated.
segment_pos = cls._expand_to_pair(segment_pos) if segment_pos is None:
else: # THD + load balanced segment_ids are not supported in this function
# BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself
if is_segment_ids_reordered:
assert not is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered"
" (Striped) THD inputs. Please pass the load balanced reordered segment_pos"
" and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}"
" using convenience function reorder_causal_load_balancing()"
)
assert is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered (Dual"
" Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load"
" balanced reordered. The reordering for these is performed within the"
" primitive"
)
# Generate the default pos for THD and BSHD non-reordered segment_ids
def generate_default_pos(seg_ids):
if is_thd:
batch_size, seq_size = seg_ids.shape
# Assume that the first token belongs to a segment and is not a padded token
first_is_segment = jnp.full((batch_size, 1), True, dtype=bool)
# Get segment start positions
segment_start = jnp.concatenate(
[
first_is_segment,
(seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0),
],
axis=-1,
)
# Get offset for location where new segment starts
segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(
segment_start
)
segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx)
# Get the last non-zero index - after this everything is padding
# (B,)
last_nonzero_idx = jax.vmap(
lambda segids_row: jnp.max(
jnp.where(segids_row != 0, jnp.arange(seq_size), -1)
)
)(seg_ids)
seg_pos_no_thd = jnp.arange(seq_size)
# Get a mask which can be used to zero out all the padding at the end (after the non-zero index)
mask = seg_pos_no_thd <= last_nonzero_idx[:, None]
# Get the unmasked seg_pos for the THD sequence
seg_pos = (
jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape)
- segment_start_offsets
)
# Use the mask to zero out the padding at the end (after the non-zero index)
segment_pos = jax.vmap(
lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0)
)(seg_pos, mask)
return segment_pos
def generate_default_pos(segment_ids): seqlen = seg_ids.shape[-1]
seqlen = segment_ids.shape[-1] return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)
return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)
q_seg_pos = generate_default_pos(q_seg_ids) q_seg_pos = generate_default_pos(q_seg_ids)
kv_seg_pos = generate_default_pos(kv_seg_ids) kv_seg_pos = generate_default_pos(kv_seg_ids)
segment_pos = (q_seg_pos, kv_seg_pos) segment_pos = (q_seg_pos, kv_seg_pos)
# Explicitly passed segment_pos
else:
segment_pos = cls._expand_to_pair(segment_pos)
return cls( return cls(
segment_ids=(q_seg_ids, kv_seg_ids), segment_ids=(q_seg_ids, kv_seg_ids),
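To illustrate the default position generation above, a worked example with made-up segment ids: for a THD row with ids [1, 1, 1, 2, 2, 0, 0], positions restart at each segment boundary and trailing padding is zeroed, giving [0, 1, 2, 0, 1, 0, 0]. The new keyword-only flags must now be passed explicitly.

import jax.numpy as jnp
from transformer_engine.jax.attention import SequenceDescriptor

seg_ids = jnp.array([[1, 1, 1, 2, 2, 0, 0]])
desc = SequenceDescriptor.from_segment_ids_and_pos(
    (seg_ids, seg_ids),
    segment_pos=None,  # generated internally as [[0, 1, 2, 0, 1, 0, 0]]
    is_thd=True,
    is_segment_ids_reordered=False,
)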
@@ -786,6 +910,7 @@ def _legacy_fused_attn(
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
+    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
@@ -793,6 +918,7 @@
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
+    softmax_offset: Optional[jnp.ndarray] = None,
):
    """
    Perform non-THD (non-packed) cuDNN fused attention.
@@ -815,6 +941,7 @@
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
+        softmax_type (AttnSoftmaxType): Type of attention softmax.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
@@ -863,10 +990,12 @@
    output = _fused_attn(
        qkv,
        bias,
+        softmax_offset,
        SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
@@ -900,6 +1029,7 @@ def fused_attn_thd(
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
+    softmax_offset: Optional[jnp.ndarray] = None,
):
    """
    Deprecated THD fused attn, please use fused_attn with SequenceDescriptor
@@ -937,6 +1067,7 @@
    output = _fused_attn(
        qkv,
        bias,
+        softmax_offset,
        SequenceDescriptor.from_seqlens_and_offsets(
            (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
        ),
@@ -945,6 +1076,7 @@
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
+        softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
        dropout_probability=dropout_probability,
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
@@ -957,15 +1089,17 @@
    return output
-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
+    softmax_offset: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
+    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
@@ -975,15 +1109,18 @@
    context_parallel_causal_load_balanced: bool,
    context_parallel_axis: str,
    context_checkpoint_name: str = "context",
+    stripe_size: int | None = None,
):
    output, _ = _fused_attn_fwd_rule(
        qkv,
        bias,
+        softmax_offset,
        sequence_descriptor,
        seed,
        attn_bias_type,
        attn_mask_type,
        qkv_layout,
+        softmax_type,
        scaling_factor,
        dropout_probability,
        is_training,
@@ -993,6 +1130,7 @@
        context_parallel_causal_load_balanced,
        context_parallel_axis,
        context_checkpoint_name=context_checkpoint_name,
+        stripe_size=stripe_size,
    )
    return output
@@ -1000,11 +1138,13 @@
def _fused_attn_fwd_rule(
    qkv,
    bias,
+    softmax_offset,
    sequence_descriptor,
    seed,
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
+    softmax_type,
    scaling_factor,
    dropout_probability,
    is_training,
@@ -1014,14 +1154,17 @@
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    context_checkpoint_name,
+    stripe_size,
):
    output, softmax_aux, rng_state = tex.fused_attn_fwd(
        qkv,
        bias,
+        softmax_offset,
        sequence_descriptor,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
@@ -1031,6 +1174,7 @@
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
+        stripe_size=stripe_size,
    )
    output = checkpoint_name(output, context_checkpoint_name)
    softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
@@ -1041,6 +1185,7 @@
        sequence_descriptor,
        softmax_aux,
        rng_state,
+        softmax_offset,
        output,
    )
@@ -1049,6 +1194,7 @@ def _fused_attn_bwd_rule(
    attn_bias_type,
    attn_mask_type,
    qkv_layout,
+    softmax_type,
    scaling_factor,
    dropout_probability,
    is_training,
@@ -1058,6 +1204,7 @@
    context_parallel_causal_load_balanced,
    context_parallel_axis,
    context_checkpoint_name,
+    stripe_size,
    ctx,
    dz,
):
@@ -1068,11 +1215,13 @@
        sequence_descriptor,
        softmax_aux,
        rng_state,
+        softmax_offset,
        output,
    ) = ctx
-    grad_qkv, grad_bias = tex.fused_attn_bwd(
+    grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
        qkv,
        bias,
+        softmax_offset,
        softmax_aux,
        rng_state,
        output,
@@ -1080,6 +1229,7 @@
        sequence_descriptor,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
@@ -1089,12 +1239,16 @@
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
+        stripe_size=stripe_size,
    )
    if attn_bias_type == AttnBiasType.NO_BIAS:
        grad_bias = None
+    if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
+        grad_softmax_offset = None
    return (
        grad_qkv,
        grad_bias,
+        grad_softmax_offset,
        None,
        None,
    )
@@ -1111,6 +1265,7 @@ def fused_attn(
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
+    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
@@ -1120,6 +1275,8 @@
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    context_checkpoint_name: str = "context",
+    softmax_offset: Optional[jnp.ndarray] = None,
+    stripe_size: int | None = None,
):
    """
    Perform cuDNN fused attention.
@@ -1139,6 +1296,7 @@
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
+        softmax_type (AttnSoftmaxType): Type of attention softmax.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
@@ -1153,6 +1311,14 @@
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
        context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
+        softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
+            [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
+            If provided, this parameter will receive gradients during backpropagation.
+        stripe_size (int | None):
+            Indicates the stripe size to be used when using ReorderStrategy.Striped.
+            Currently, stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas stripe_size = 1
+            is supported for both CP + THD + Striped + AG and CP + THD + Striped + P2P (Ring).
+            None indicates no striping strategy
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
@@ -1200,6 +1366,7 @@
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
@@ -1208,15 +1375,18 @@
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
+        softmax_offset=softmax_offset,
    )
    output = _fused_attn(
        qkv,
        bias,
+        softmax_offset,
        sequence_descriptor,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
+        softmax_type=softmax_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
@@ -1226,5 +1396,6 @@
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
        context_checkpoint_name=context_checkpoint_name,
+        stripe_size=stripe_size,
    )
    return output
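A hedged usage sketch of the extended `fused_attn` signature with a learnable softmax (tensor shapes, dtypes, and hyperparameters here are illustrative assumptions, not taken from this commit): the offset is a [1, num_heads, 1, 1] float32 parameter that receives gradients in the backward pass.

import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVLayout,
    SequenceDescriptor, fused_attn,
)

b, s, h, d = 2, 512, 16, 64
q = jnp.zeros((b, s, h, d), jnp.bfloat16)
k = jnp.zeros((b, s, h, d), jnp.bfloat16)
v = jnp.zeros((b, s, h, d), jnp.bfloat16)
softmax_offset = jnp.zeros((1, h, 1, 1), jnp.float32)  # learnable parameter
seq_desc = SequenceDescriptor.from_seqlens(
    (jnp.full((b,), s, jnp.int32), jnp.full((b,), s, jnp.int32))
)
out = fused_attn(
    (q, k, v),
    None,  # bias
    seq_desc,
    None,  # seed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
    scaling_factor=d**-0.5,
    dropout_probability=0.0,
    is_training=True,
    softmax_offset=softmax_offset,  # receives grad_softmax_offset in the backward pass
)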
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Checkpoint policies for Transformer Engine in JAX.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for c++ extensions"""
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
+    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
-    QuantizeLayout,
)
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for amax calculation"""
@@ -73,7 +73,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
        transpose_batch_sequence,
    ):
        """
-        amax calcuation abstract
+        amax calculation abstract
        """
        del amax_scope, transpose_batch_sequence
@@ -251,7 +251,7 @@ class RHTAmaxCalculationPrimitive(BasePrimitive):
        flatten_axis,
    ):
        """
-        amax calcuation implementation
+        amax calculation implementation
        """
        assert RHTAmaxCalculationPrimitive.inner_primitive is not None
        (
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
+    AttnSoftmaxType,
    QKVLayout,
    QKVFormat,
    CPStrategy,
    SequenceDescriptor,
)
+from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available

from .base import BasePrimitive, register_primitive
from .misc import (
@@ -61,6 +63,7 @@ __all__ = [
    meta_fields=[
        "attn_bias_type",
        "attn_mask_type",
+        "softmax_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
@@ -70,6 +73,7 @@
        "context_parallel_load_balanced",
        "cp_axis",
        "cp_striped_window_size",
+        "stripe_size",
    ],
)
@dataclass(frozen=True)
@@ -80,6 +84,7 @@ class _FusedAttnConfig:
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
+    softmax_type: AttnSoftmaxType
    qkv_layout: QKVLayout
    scaling_factor: float
    dropout_probability: float
@@ -88,7 +93,10 @@
    window_size: Tuple[int, int]
    context_parallel_load_balanced: bool
    cp_axis: str
-    cp_striped_window_size: Tuple[int, int]  # Only for CP + Ring + THD + SWA
+    cp_striped_window_size: Tuple[int, int]  # Only for CP + Ring P2P + THD + SWA
+    stripe_size: (
+        int | None
+    )  # Only for CP + Striped. For Ring P2P, stripe_size=1 only. For AG, stripe_size>=1.

@dataclass(frozen=True)
@@ -103,6 +111,7 @@ class FusedAttnHelper:
    qkv_layout: QKVLayout
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
+    softmax_type: AttnSoftmaxType
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
@@ -125,6 +134,7 @@ class FusedAttnHelper:
        self.qkv_layout.value,
        self.attn_bias_type.value,
        self.attn_mask_type.value,
+        self.softmax_type.value,
        self.dropout_probability,
        self.q_num_heads,
        self.kv_num_heads,
@@ -254,7 +264,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
    name = "te_fused_attn_forward_ffi"
    multiple_results = True
-    impl_static_args = (13,)
+    impl_static_args = (14,)
    inner_primitive = None
    outer_primitive = None
@@ -264,6 +274,7 @@
        k_aval,
        v_aval,
        bias_aval,
+        softmax_offset_aval,
        seed_aval,
        q_seqlen_or_cu_seqlen_aval,
        kv_seqlen_or_cu_seqlen_aval,
@@ -312,6 +323,7 @@
            config.qkv_layout,
            config.attn_bias_type,
            config.attn_mask_type,
+            config.softmax_type,
            config.dropout_probability,
            attn_heads,
            num_gqa_groups,
@@ -375,6 +387,7 @@
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
+            config.softmax_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
@@ -386,6 +399,12 @@
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

+        assert softmax_offset_aval.dtype == jnp.float32
+        if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
+            assert softmax_offset_aval.shape == (1, attn_heads, 1, 1)
+        else:
+            assert softmax_offset_aval.shape == (0,)
+
        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
@@ -405,6 +424,7 @@
        k,
        v,
        bias,
+        softmax_offset,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
@@ -453,6 +473,7 @@
        k,
        v,
        bias,
+        softmax_offset,
        seed,
        q_cu_seqlen,
        kv_cu_seqlen,
@@ -481,6 +502,7 @@
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
+            softmax_type=int(config.softmax_type.value),
        )

    @staticmethod
@@ -489,6 +511,7 @@
        k,
        v,
        bias,
+        softmax_offset,
        seed,
        q_seqlen,
        kv_seqlen,
@@ -508,7 +531,6 @@
            segment_ids=(_q_segment_ids, _kv_segment_ids),
            segment_pos=(_q_segment_pos, _kv_segment_pos),
        )
        (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
            sequence_descriptor.get_seqlens_and_offsets(
                config.attn_mask_type,
@@ -517,7 +539,6 @@
                config.max_segments_per_seq,
            )
        )
        if config.qkv_layout.is_thd():

            def _fix_len_take(x, condition, fill_value=-1):
@@ -579,6 +600,7 @@
            k,
            v,
            bias,
+            softmax_offset,
            seed,
            q_cu_seqlen,
            kv_cu_seqlen,
@@ -596,7 +618,7 @@
    def batcher(batched_args, batch_dims, *, config):
        check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
-        q_bdim, _, _, _, seed_bdim, *_ = batch_dims
+        q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
        out_bdims = q_bdim, q_bdim, seed_bdim
        return (
@@ -662,7 +684,7 @@
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
-        arg_shardings[4] = seed_sharding
+        arg_shardings[5] = seed_sharding
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
@@ -710,7 +732,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
    name = "te_fused_attn_backward_ffi"
    multiple_results = True
-    impl_static_args = (16,)
+    impl_static_args = (17,)
    inner_primitive = None
    outer_primitive = None
@@ -720,6 +742,7 @@
        k_aval,
        v_aval,
        bias_aval,
+        softmax_offset_aval,
        softmax_aux_aval,
        rng_state_aval,
        output_aval,
@@ -781,6 +804,7 @@
            config.dropout_probability,
            config.attn_bias_type.value,
            config.attn_mask_type.value,
+            config.softmax_type.value,
            config.qkv_layout.value,
            jax_dtype_to_te_dtype(q_aval.dtype),
            config.is_training,
@@ -798,15 +822,39 @@
            shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
        )

-        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
+        # Validate incoming softmax_offset shape and dtype
+        assert (
+            softmax_offset_aval.dtype == jnp.float32
+        ), f"Incorrect softmax_offset dtype: {softmax_offset_aval.dtype}, expected: {jnp.float32}"
+        if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
+            assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), (
+                f"Incorrect softmax_offset shape for {config.softmax_type}:"
+                f" {softmax_offset_aval.shape}, expected: (1, {attn_heads}, 1, 1)"
+            )
+        else:
+            assert softmax_offset_aval.shape == (0,), (
+                f"Incorrect softmax_offset shape for {config.softmax_type}:"
+                f" {softmax_offset_aval.shape}, expected: (0,)"
+            )
+
+        if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
+            dsoftmax_offset_aval = q_aval.update(
+                shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
+            )
+        else:
+            dsoftmax_offset_aval = q_aval.update(shape=(1, attn_heads, 1, 1), dtype=jnp.float32)
+
+        return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract
        """
-        dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
-        return dq_aval, dk_aval, dv_aval, dbias_aval
+        dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, _ = (
+            FusedAttnBwdPrimitive.abstract(*args, **kwargs)
+        )
+        return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval
    @staticmethod
    def lowering(
@@ -815,6 +863,7 @@
        k,
        v,
        bias,
+        softmax_offset,
        softmax_aux,
        rng_state,
        output,
@@ -866,6 +915,7 @@
            k,
            v,
            bias,
+            softmax_offset,
            softmax_aux,
            rng_state,
            output,
@@ -897,6 +947,7 @@
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
+            softmax_type=int(config.softmax_type.value),
        )

    @staticmethod
@@ -905,6 +956,7 @@
        k,
        v,
        bias,
+        softmax_offset,
        softmax_aux,
        rng_state,
        output,
@@ -993,11 +1045,12 @@
        q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())

-        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
+        dq, dk, dv, dbias, dsoftmax_offset, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
+            softmax_offset,
            softmax_aux,
            rng_state,
            output,
@@ -1012,15 +1065,15 @@
            _kv_segment_pos,
            config=config,
        )
-        return dq, dk, dv, dbias
+        return dq, dk, dv, dbias, dsoftmax_offset

    @staticmethod
    def batcher(batched_args, batch_dims, *, config):
        check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
-        q_bdim, k_bdim, v_bdim, *_ = batch_dims
-        out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
+        q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
+        out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
        return (
            FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
            out_bdims,
@@ -1033,11 +1086,13 @@
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
+        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
-        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
+        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
+        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding, dsoftmax_offset_sharding)

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
@@ -1046,21 +1101,30 @@
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
+        softmax_offset_spec = get_padded_spec(arg_infos[4])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
+        dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        arg_shardings[-1] = arg_shardings[-3]
        arg_shardings[-2] = arg_shardings[-4]
        arg_shardings = tuple(arg_shardings)
-        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
+        out_shardings = (
+            dq_sharding,
+            dk_sharding,
+            dv_sharding,
+            dbias_sharding,
+            dsoftmax_offset_sharding,
+        )

        def sharded_impl(
            q,
            k,
            v,
            bias,
+            softmax_offset,
            softmax_aux,
            rng_state,
            output,
@@ -1074,11 +1138,13 @@
            _q_segment_pos,
            _kv_segment_pos,
        ):
-            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
+            local_dq, local_dk, local_dv, local_dbias, local_dsoftmax_offset = (
+                FusedAttnBwdPrimitive.impl(
                    q,
                    k,
                    v,
                    bias,
+                    softmax_offset,
                    softmax_aux,
                    rng_state,
                    output,
@@ -1093,17 +1159,22 @@
                    _kv_segment_pos,
                    config=config,
                )
+            )
            global_dbias = local_dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
-            return local_dq, local_dk, local_dv, global_dbias
+
+            global_dsoftmax_offset = local_dsoftmax_offset
+            if config.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+                global_dsoftmax_offset = all_reduce_sum_along_dp_fsdp(local_dsoftmax_offset, mesh)
+            return local_dq, local_dk, local_dv, global_dbias, global_dsoftmax_offset

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(config, mesh, value_types, result_types):
        del config, mesh
-        # We only care about the four first arguments.
        # Keep in sync with `infer_sharding_from_operands`.
        input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
        output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
...@@ -1165,31 +1236,38 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig ...@@ -1165,31 +1236,38 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig
return combined.reshape(ori_tensor_shape) return combined.reshape(ori_tensor_shape)
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool): def reorder_causal_striped(
tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_size: int = 1
):
"""Reorders a tensor for load balancing with striped pattern""" """Reorders a tensor for load balancing with striped pattern"""
origin_shape = tensor.shape origin_shape = tensor.shape
if origin_shape[seq_dim] % cp_size != 0: if stripe_size <= 0:
raise ValueError( raise ValueError(
"Expected origin_shape[seq_dim] is multiple of cp_size but got" f"Incorrect value for CP reordering {stripe_size=}. stripe_size must be a positive"
f" {origin_shape[seq_dim]=} and {cp_size=}" " integer"
)
if origin_shape[seq_dim] % (cp_size * stripe_size) != 0:
raise ValueError(
"Expected origin_shape[seq_dim] is multiple of cp_size*stripe_size but got"
f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_size=}, {cp_size*stripe_size=}"
) )
if not is_inverse: if not is_inverse:
new_shape = [ new_shape = [
*origin_shape[:seq_dim], *origin_shape[:seq_dim],
*[origin_shape[seq_dim] // cp_size, cp_size], *[origin_shape[seq_dim] // (cp_size * stripe_size), cp_size, stripe_size],
*origin_shape[seq_dim + 1 :], *origin_shape[seq_dim + 1 :],
] ]
else: else:
new_shape = [ new_shape = [
*origin_shape[:seq_dim], *origin_shape[:seq_dim],
*[cp_size, origin_shape[seq_dim] // cp_size], *[cp_size, origin_shape[seq_dim] // (cp_size * stripe_size), stripe_size],
*origin_shape[seq_dim + 1 :], *origin_shape[seq_dim + 1 :],
] ]
chunked_tensor = tensor.reshape(new_shape) striped_tensor = tensor.reshape(new_shape)
reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1) reordered_striped_tensor = jnp.swapaxes(striped_tensor, seq_dim, seq_dim + 1)
return reordered_chunked_tensor.reshape(origin_shape) return reordered_striped_tensor.reshape(origin_shape)
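For intuition, a standalone sketch of the striped reorder (illustrative only, not part of the change, using a toy 1-D sequence):

import jax.numpy as jnp

x = jnp.arange(12)  # token positions along the sequence dim
# Forward reorder: group stripes of 2 tokens into contiguous per-rank chunks.
fwd = reorder_causal_striped(x, cp_size=2, seq_dim=0, is_inverse=False, stripe_size=2)
# fwd -> [0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11]  (rank 0 stripes, then rank 1 stripes)
# Inverse reorder: recovers the original ordering.
inv = reorder_causal_striped(fwd, cp_size=2, seq_dim=0, is_inverse=True, stripe_size=2)
# inv -> [0, 1, 2, ..., 11]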
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -1203,43 +1281,85 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1203,43 +1281,85 @@ class _FusedAttnCPWithAllGatherHelper:
"""Checks if the context parallel implementation is supported by the given arguments.""" """Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused attention" header = "Context parallel fused attention"
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] allowed_layouts = [
QKVLayout.BSHD_BS2HD,
QKVLayout.BSHD_BSHD_BSHD,
QKVLayout.THD_T2HD,
QKVLayout.THD_THD_THD,
]
if self.config.qkv_layout not in allowed_layouts: if self.config.qkv_layout not in allowed_layouts:
raise ValueError( raise ValueError(
f"{header} only supports layouts:" f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
) )
if (not self.config.qkv_layout.is_thd() and self.config.stripe_size is not None) or (
self.config.qkv_layout.is_thd() and self.config.stripe_size is None
):
raise ValueError(
f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped"
" load balancing with THD layouts"
)
if self.config.attn_bias_type != AttnBiasType.NO_BIAS: if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.qkv_layout.is_thd():
allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK)
if self.config.attn_mask_type not in allowed_masks: if self.config.attn_mask_type not in allowed_masks:
raise ValueError( raise ValueError(
f"{header} only supports masking types: " f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
) )
# For CP + AG + THD + Striped, only PADDING_CAUSAL_MASK is allowed (e.g. NO_MASK is rejected)
if (
self.config.attn_mask_type is not AttnMaskType.PADDING_CAUSAL_MASK
and self.config.qkv_layout.is_thd()
):
raise ValueError(f"{header} only supports PADDING_CAUSAL_MASK for THD types")
if self.config.max_segments_per_seq != 1: if self.config.max_segments_per_seq != 1 and (not self.config.qkv_layout.is_thd()):
raise ValueError( raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:" f"{header} only supports max_segments_per_seq == 1 for BSHD layouts, got:"
f" {self.config.max_segments_per_seq}" f" {self.config.max_segments_per_seq}"
) )
if self.config.dropout_probability != 0.0: if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout") raise ValueError(f"{header} does not support dropout")
if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
raise ValueError(
f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
)
def get_adjusted_mask(self): def get_adjusted_mask(self):
"""Converts the mask for context parallelism.""" """Converts the mask for context parallelism."""
if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: if (
self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK
and not self.config.qkv_layout.is_thd()
): # BSHD AG case only
return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
if (
self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.config.qkv_layout.is_thd()
): # THD AG case only
return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type return self.config.attn_mask_type
def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size):
"""Converts the max segments per seq for context parallelism AG + THD."""
# Estimating adjusted max segments per seq
return (
max_seqlen // (self.config.stripe_size * cp_size)
) + self.config.max_segments_per_seq
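For example (worked numbers, assuming the formula above): an all-gathered max_seqlen of 64 with stripe_size=4 and cp_size=2 gives 64 // (4 * 2) = 8 stripes per rank; with a global max_segments_per_seq of 2, the per-rank estimate is 8 + 2 = 10 segment slots.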
def get_step_config(self) -> _FusedAttnConfig: def get_step_config(self) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention.""" """Returns a _FusedAttnConfig for single CP step call to fused attention."""
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(), attn_mask_type=self.get_adjusted_mask(),
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout, qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
...@@ -1249,10 +1369,29 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1249,10 +1369,29 @@ class _FusedAttnCPWithAllGatherHelper:
context_parallel_load_balanced=self.config.context_parallel_load_balanced, context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis, cp_axis=self.config.cp_axis,
cp_striped_window_size=None, cp_striped_window_size=None,
stripe_size=self.config.stripe_size,
)
def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size),
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
stripe_size=self.config.stripe_size,
) )
def all_gather_kv(self, k, v): def all_gather_kv(self, k, v):
"""Performs a all-gather of k and v over context parallel ranks.""" """Performs an all-gather of k and v over context parallel ranks."""
def ag(x): def ag(x):
x = lax_paral_op( x = lax_paral_op(
...@@ -1260,6 +1399,9 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1260,6 +1399,9 @@ class _FusedAttnCPWithAllGatherHelper:
) )
if self.config.context_parallel_load_balanced: if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
if self.config.qkv_layout.is_thd():
x = reorder_causal_striped(x, cp_size, 1, True, self.config.stripe_size)
else:
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
return x return x
...@@ -1270,12 +1412,35 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1270,12 +1412,35 @@ class _FusedAttnCPWithAllGatherHelper:
return k, v # fall through return k, v # fall through
def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos):
"""Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks."""
kv_segment_ids = lax_paral_op(
kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
)
kv_segment_pos = lax_paral_op(
kv_segment_pos, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
)
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
if self.config.qkv_layout.is_thd():
kv_segment_ids_ag = reorder_causal_striped(
kv_segment_ids, cp_size, 1, True, self.config.stripe_size
)
kv_segment_pos_ag = reorder_causal_striped(
kv_segment_pos, cp_size, 1, True, self.config.stripe_size
)
return kv_segment_ids_ag, kv_segment_pos_ag
return kv_segment_ids, kv_segment_pos # fall through
def reduce_scatter_dkv(self, dk, dv): def reduce_scatter_dkv(self, dk, dv):
"""Performs a reduce-scatter of dk and dv over context parallel ranks.""" """Performs a reduce-scatter of dk and dv over context parallel ranks."""
def rs(x): def rs(x):
if self.config.context_parallel_load_balanced: if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
if self.config.qkv_layout.is_thd():
x = reorder_causal_striped(x, cp_size, 1, False, self.config.stripe_size)
else:
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
return lax_paral_op( return lax_paral_op(
...@@ -1349,6 +1514,227 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1349,6 +1514,227 @@ class _FusedAttnCPWithAllGatherHelper:
return dk, dv # fall through return dk, dv # fall through
# Below are the sharded post AG q seg ids and pos for a given rank:
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
# segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]]
# seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]
# seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]]
def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
"""Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos"""
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask = q_segment_ids != 0
max_size = q_segment_ids.shape[-1]
non_zero_indices = jax.vmap(
lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)
# Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos
# Clip -1 to 0 for safe indexing
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(q_segment_ids, clipped_indices, axis=-1), 0
)
valid_segment_pos = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), 0
)
# Create a mask for actual valid entries (not padding)
actual_valid = valid_segment_ids != 0
# First element is True only if it's actually valid
first_is_segment = actual_valid[..., 0:1]
# Detect segment breaks in the valid tokens only (not full seq)
# Padding will always be true as the segment change condition is being applied
# on the valid segments (which have padding at the end so they'll always trigger True)
segment_changes = jnp.concatenate(
[
first_is_segment, # First valid element starts a segment
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
],
axis=-1,
)
new_segment_ids = jnp.cumsum(segment_changes, axis=-1)
seqlens_pre = jax.vmap(
lambda av_row, nsi_row: jnp.where(av_row, nsi_row, 0).astype(jnp.int32)
)(actual_valid, new_segment_ids)
seqlens_all = jax.vmap(
lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:]
)(seqlens_pre)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
return seqlens_all_pad_neg
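These extractors read nothing from self, so they can be exercised standalone on the documented example above (a hypothetical sanity check, not part of the change):

import jax.numpy as jnp

q_ids = jnp.array([[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]])
q_pos = jnp.array([[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]])
seqlens = _FusedAttnCPWithAllGatherHelper.q_seqlens_for_striped_for_rank(
    None, q_ids, q_pos, max_segments_per_seq=7
)
# seqlens -> [[4, 4, 4, -1, -1, -1, -1]]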
# Below are the sharded post AG q seg ids and pos for a given rank:
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, False, False, False]]
# segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]
# seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1]]
def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
"""Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos"""
segment_changes = jnp.concatenate(
[
jnp.full(
(q_segment_pos.shape[0], 1), True, dtype=bool
), # First valid element starts a segment
(q_segment_pos[..., 1:] != q_segment_pos[..., :-1] + 1), # Segment pos changed
],
axis=-1,
)
# Remove any padded region segment changes
segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False)
# Get the indices for segment changes (these are the offsets)
seq_offsets = jax.vmap(
lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0]
)(segment_changes_masked)
return seq_offsets
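The same documented inputs (q_ids and q_pos from the sketch above) can be fed through this extractor as well, again as a hypothetical check:

offsets = _FusedAttnCPWithAllGatherHelper.q_seqoffsets_for_striped_for_rank(
    None, q_ids, q_pos, max_segments_per_seq=7
)
# offsets -> [[0, 8, 12, -1, -1, -1, -1]]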
# Below are the sharded post AG kv seg ids and pos for a given rank:
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# non_zero_mask = [[ True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True]]
# non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
# segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]]
# selected_values = [[ 4, 15, 31, -1, -1, -1, -1]]
def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq):
"""Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos"""
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask = kv_segment_ids != 0
max_size = kv_segment_ids.shape[-1]
non_zero_indices = jax.vmap(
lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)
# Pick non zero seg ids and seg pos using take_along_axis
# Clip -1 to 0 for safe indexing
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(kv_segment_ids, clipped_indices, axis=-1), 0
)
valid_segment_pos = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), 0
)
actual_valid = valid_segment_ids != 0
# Detect segment breaks (only for non-zero segments)
segment_changes = jnp.concatenate(
[
(
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
& actual_valid[..., 1:]
)
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
actual_valid[..., -1:],
],
axis=-1,
)
# Get the indices for segment changes
segment_changes_valid = jax.vmap(
lambda sc_row, av_row: jnp.where(
sc_row & av_row, size=max_segments_per_seq, fill_value=-1
)[0]
)(segment_changes, actual_valid)
safe_indices = jnp.maximum(segment_changes_valid, 0)
# Select values using take_along_axis per row
selected_values = jnp.where(
segment_changes_valid >= 0,
jnp.take_along_axis(valid_segment_pos, safe_indices, axis=-1) + 1,
-1,
)
return selected_values
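Continuing the same hypothetical check with kv_ids and kv_pos set to the same arrays as q_ids and q_pos above (the returned values are the per-segment end positions, i.e. last segment_pos + 1, padded with -1):

kv_lens = _FusedAttnCPWithAllGatherHelper.kv_seqlens_for_striped_for_rank(
    None, kv_ids, kv_pos, max_segments_per_seq=7
)
# kv_lens -> [[4, 15, 31, -1, -1, -1, -1]]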
# Below are the sharded kv seg ids and pos for a given rank, along with their all-gathered (AG) counterparts:
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
# 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
# kv_segment_pos_ag = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
# 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
# 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# segment_changes_first_true_masked = [[ True, False, False, False, False, False, False, False, True,
# False, False, False, True, False, False, False]]
# segment_changes_indices = [[ 0, 8, 12, -1, -1, -1, -1]]
# segment_ids = [[ 1, 2, 2, -1, -1, -1, -1]]
# segment_changes_ag_first_true_masked = [[ True, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, True, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False]
# segment_changes_ag_indices = [[ 0, 21, -1, -1, -1, -1, -1]]
# seq_offsets = [[ 0, 21, 21, -1, -1, -1, -1]]
def kv_seqoffsets_for_striped_for_rank(
self,
kv_segment_pos,
kv_segment_ids,
kv_segment_pos_ag,
kv_segment_ids_ag,
max_segments_per_seq,
):
"""Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
AG kv seg ids and seg pos."""
# Calculate the segment pos change mask
segment_changes_first_true = jnp.concatenate(
[
jnp.full(
(kv_segment_pos.shape[0], 1), True, dtype=bool
), # Assume valid element starts a segment and mask afterwards
(kv_segment_pos[..., 1:] != kv_segment_pos[..., :-1] + 1), # Segment pos changed
],
axis=-1,
)
segment_changes_first_true_masked = jnp.where(
kv_segment_ids != 0, segment_changes_first_true, False
)
# Get segment change indices for rank
segment_changes_indices = jax.vmap(
lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq, fill_value=-1)[0]
)(segment_changes_first_true_masked)
# Get segment ids associated with the segment_changes_indices for rank
segment_ids = jax.vmap(
lambda sci_row, ksi_row: jnp.where(sci_row >= 0, ksi_row[sci_row], -1)
)(segment_changes_indices, kv_segment_ids)
# Get segment change indices for AG
segment_changes_ag_first_true = jnp.concatenate(
[
jnp.full(
(kv_segment_pos.shape[0], 1), True, dtype=bool
), # Assume valid element starts a segment and mask afterwards
(
kv_segment_pos_ag[..., 1:] != kv_segment_pos_ag[..., :-1] + 1
), # Segment pos changed
],
axis=-1,
)
segment_changes_ag_first_true_masked = jnp.where(
kv_segment_ids_ag != 0, segment_changes_ag_first_true, False
)
# Get segment change indices for AG
segment_changes_ag_indices = jax.vmap(
lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq, fill_value=-1)[0]
)(segment_changes_ag_first_true_masked)
# Use the segment ids picked per rank to get the offsets from the AG indices
seq_offsets = jax.vmap(
lambda si_row, sca_row: jnp.where(si_row > 0, sca_row[si_row - 1], -1)
)(segment_ids, segment_changes_ag_indices)
return seq_offsets
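And the offsets variant, reconstructing the documented all-gathered arrays programmatically (a hypothetical check; note this method takes pos before ids):

kv_ids_ag = jnp.array([[1] * 8 + [0] * 13 + [2] * 31 + [0] * 12])
kv_pos_ag = jnp.concatenate(
    [jnp.arange(21), jnp.arange(32), jnp.zeros(11, dtype=jnp.int32)]
)[None, :]
kv_offs = _FusedAttnCPWithAllGatherHelper.kv_seqoffsets_for_striped_for_rank(
    None, kv_pos, kv_ids, kv_pos_ag, kv_ids_ag, max_segments_per_seq=7
)
# kv_offs -> [[0, 21, 21, -1, -1, -1, -1]]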
class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
""" """
...@@ -1376,7 +1762,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1376,7 +1762,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -1385,6 +1771,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1385,6 +1771,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -1404,7 +1791,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1404,7 +1791,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# meeting the expectation of the SPMD model. # meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop. # mask/sequence length tensor to avoid this unrolled loop.
def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): def _cross_attn(idx, q, k, v, bias, softmax_offset, q_seqlen, kv_seqlen, seed):
kv_max_seqlen = k.shape[1] kv_max_seqlen = k.shape[1]
kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size" assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"
...@@ -1425,12 +1812,12 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1425,12 +1812,12 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
q_seqlen_for_step = q_seqlen / (cp_size * 2) q_seqlen_for_step = q_seqlen / (cp_size * 2)
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks
output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl( output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
q_split[sub_idx], q_split[sub_idx],
k_unmasked, k_unmasked,
v_unmasked, v_unmasked,
bias, bias,
softmax_offset,
seed, seed,
q_seqlen_for_step, q_seqlen_for_step,
kv_seqlen_for_step, kv_seqlen_for_step,
...@@ -1453,7 +1840,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1453,7 +1840,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_ag, v_ag = helper.all_gather_kv(k, v) k_ag, v_ag = helper.all_gather_kv(k, v)
functions = [ functions = [
partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed) partial(
_cross_attn, idx, q, k_ag, v_ag, bias, softmax_offset, q_seqlen, kv_seqlen, seed
)
for idx in range(cp_size) for idx in range(cp_size)
] ]
...@@ -1492,18 +1881,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1492,18 +1881,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def impl( def impl(
q, q,
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1527,6 +1925,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1527,6 +1925,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1562,11 +1961,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1562,11 +1961,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx] num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl( dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
q_split[sub_idx], q_split[sub_idx],
k_unmasked, k_unmasked,
v_unmasked, v_unmasked,
bias, bias,
softmax_offset,
softmax_aux_split[sub_idx], softmax_aux_split[sub_idx],
rng_state, rng_state,
output_split[sub_idx], output_split[sub_idx],
...@@ -1604,6 +2004,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1604,6 +2004,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_ag, k_ag,
v_ag, v_ag,
bias, bias,
softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -1621,7 +2022,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1621,7 +2022,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions) dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local) dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)
return dq, dk, dv, dbias # Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(softmax_offset)
return dq, dk, dv, dbias, dummy_dsoftmax_offset
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
...@@ -1629,6 +2032,314 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1629,6 +2032,314 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Attention Forward with Context Parallelism and Striped Load Balancing Primitive
This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unnecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[5] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def impl(
q,
k,
v,
bias,
softmax_offset,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
): # pylint: disable=unused-argument
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# cuDNN does not support right-aligned masking with dynamic sequence length padding.
# Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
# to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
# meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop.
# Each rank receives the ag k and v along with the ag kv seg ids and kv seg offsets
# Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos,
# _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in.
def _cross_attn(
q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed
):
# The helper generates the seqlens and offsets for q and kv and then passes them down to FusedAttnFwdPrimitive.
# Unset the segment_ids and segment_pos by passing placeholders so that seqlens_from_segment_ids_pos()
# does not go down that route but instead just picks up the pre-computed seqlens and offsets passed into it.
kv_max_seqlen = k.shape[1]
# Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq
adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(
max_seqlen=kv_max_seqlen, cp_size=cp_size
)
q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(
_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq
)
q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(
q_segment_ids=_q_segment_ids,
q_segment_pos=_q_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(
kv_segment_ids=_kv_segment_ids,
kv_segment_pos=_kv_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(
kv_segment_pos=_kv_segment_pos,
kv_segment_ids=_kv_segment_ids,
kv_segment_pos_ag=kv_segment_pos_ag,
kv_segment_ids_ag=kv_segment_ids_ag,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
q, # sharded for rank
k, # ag
v, # ag
bias,
softmax_offset,
seed,
q_seqlens_for_rank,
kv_seqlens_for_rank,
q_seq_offsets_for_rank,
kv_seq_offsets_for_rank,
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
config=helper.get_step_config_for_striped(
max_seqlen=kv_max_seqlen, cp_size=cp_size
),
)
return output, softmax_aux, rng_state
# AG the k, v, kv_segment_ids and kv_segment_pos
k_ag, v_ag = helper.all_gather_kv(k, v)
_kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
_kv_segment_ids, _kv_segment_pos
)
functions = [
partial(
_cross_attn,
q,
k_ag,
v_ag,
bias,
softmax_offset,
_kv_segment_ids_ag,
_kv_segment_pos_ag,
seed,
)
for _ in range(cp_size)
]
return lax.switch(cp_rank, functions)
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive)
class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Attention Backward with Context Parallelism and Striped Load Balancing Primitive.
This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
The gradients are subsequently reduce-scattered back to each context parallel rank.
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unnecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
# Ensure we can support this configuration with context parallelism.
helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
helper.check_supported()
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def impl(
q,
k,
v,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
): # pylint: disable=unused-argument
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# See the comment in FusedAttnCPStripedWithAllGatherFwdPrimitive.partition for why we define this function.
def _cross_attn_bwd(
q,
k,
v,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
_q_segment_ids,
kv_segment_ids_ag,
_q_segment_pos,
kv_segment_pos_ag,
):
# The helper generates the seqlens and offsets for q and kv and then passes them down to FusedAttnBwdPrimitive.
# Unset the segment_ids and segment_pos by passing placeholders so that seqlens_from_segment_ids_pos()
# does not go down that route but instead just picks up the pre-computed seqlens and offsets passed into it.
kv_max_seqlen = k.shape[1]
# Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq
adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(
max_seqlen=kv_max_seqlen, cp_size=cp_size
)
q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(
_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq
)
q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(
q_segment_ids=_q_segment_ids,
q_segment_pos=_q_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(
kv_segment_ids=_kv_segment_ids,
kv_segment_pos=_kv_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(
kv_segment_pos=_kv_segment_pos,
kv_segment_ids=_kv_segment_ids,
kv_segment_pos_ag=kv_segment_pos_ag,
kv_segment_ids_ag=kv_segment_ids_ag,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
q, # sharded for rank
k, # ag
v, # ag
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
q_seqlens_for_rank,
kv_seqlens_for_rank,
q_seq_offsets_for_rank,
kv_seq_offsets_for_rank,
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
config=helper.get_step_config_for_striped(
max_seqlen=kv_max_seqlen, cp_size=cp_size
),
)
return dq_local, dk_local, dv_local, dbias_local
# AG the k, v, kv_segment_ids and kv_segment_pos
k_ag, v_ag = helper.all_gather_kv(k, v)
_kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
_kv_segment_ids, _kv_segment_pos
)
functions = [
partial(
_cross_attn_bwd,
q,
k_ag,
v_ag,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
_q_segment_ids,
_kv_segment_ids_ag,
_q_segment_pos,
_kv_segment_pos_ag,
)
for _ in range(cp_size)
]
dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
# RS the dk and dv
dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)
# Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(softmax_offset)
return dq, dk, dv, dbias, dummy_dsoftmax_offset
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive)
@dataclass(frozen=True) @dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper: class _FusedAttnCPWithP2PHelper:
"""Helper class to assist with running the P2P ring strategy for CP attention.""" """Helper class to assist with running the P2P ring strategy for CP attention."""
...@@ -1639,7 +2350,8 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1639,7 +2350,8 @@ class _FusedAttnCPWithP2PHelper:
@staticmethod @staticmethod
def use_scanloop(): def use_scanloop():
"""Returns true if the implementation will use a scan loop for iteration.""" """Returns true if the implementation will use a scan loop for iteration."""
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1"))) # TODO(KshitijLakhani): Reset default to 1, once the extra kv permute op issue is resolved
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "0")))
return use_scan return use_scan
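For example, a user could opt back into the scan loop by setting the environment variable before the first call into ring attention (hypothetical usage):

import os

os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"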
def check_supported(self): def check_supported(self):
...@@ -1679,13 +2391,20 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1679,13 +2391,20 @@ class _FusedAttnCPWithP2PHelper:
if self.config.dropout_probability != 0.0: if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout") raise ValueError(f"{header} does not support dropout")
# We want to encourage use of scan loop to minimize unrolling and ensure more if self.config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
# predictable scheduling from XLA. The unrolled flavor will be supported but raise ValueError(
# not the prefered implementation. f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
if not self.use_scanloop(): )
# TODO(KshitijLakhani): Flip the condition to check for disabled scan loop and warn
# against using unrolled loops once the scan issue is resolved.
# We want to discourage the use of the scan loop for now, as an additional kv permute op was
# observed with it. The scan loop flavor will remain supported but is not the preferred
# implementation until a resolution for the extra kv permute op, which degrades perf, is found.
if self.use_scanloop():
warnings.warn( warnings.warn(
"Scan loop is disabled for fused ring attention. To enable set" "Scan loop is enabled for fused ring attention. To disable set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment" " NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 in your environment"
) )
# If using scanloop, idx in scan_kv_block() will be a traced device value, but # If using scanloop, idx in scan_kv_block() will be a traced device value, but
...@@ -1703,6 +2422,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1703,6 +2422,7 @@ class _FusedAttnCPWithP2PHelper:
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
softmax_type=self.config.softmax_type,
qkv_layout=QKVLayout.BSHD_BS2HD, qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
...@@ -1712,6 +2432,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1712,6 +2432,7 @@ class _FusedAttnCPWithP2PHelper:
context_parallel_load_balanced=self.config.context_parallel_load_balanced, context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis, cp_axis=self.config.cp_axis,
cp_striped_window_size=None, cp_striped_window_size=None,
stripe_size=self.config.stripe_size,
) )
def stack_kv(self, k, v): def stack_kv(self, k, v):
...@@ -1783,7 +2504,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1783,7 +2504,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
# Ensure segment_pos gets same sharding as ID. # Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
...@@ -1795,6 +2516,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1795,6 +2516,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -1840,6 +2562,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1840,6 +2562,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1865,6 +2588,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1865,6 +2588,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part, kv_part,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1887,6 +2611,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1887,6 +2611,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen_per_step, q_seqlen_per_step,
kv_seqlen_per_step, kv_seqlen_per_step,
...@@ -1990,18 +2715,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1990,18 +2715,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2]) v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3]) bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
# Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
# Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
out_shardings = (
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported() helper.check_supported()
...@@ -2011,6 +2742,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2011,6 +2742,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2054,11 +2786,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2054,11 +2786,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def mask_compute(attn_mask_type): def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2082,11 +2815,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2082,11 +2815,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2 kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1) kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv_part, kv_part,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2120,11 +2854,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2120,11 +2854,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2 softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
) )
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dk_dv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q_part, q_part,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux_part, softmax_aux_part,
rng_state, rng_state,
output_part, output_part,
...@@ -2184,7 +2919,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2184,7 +2919,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv) dk, dv = helper.unstack_kv(dk_dv)
return dq, dk, dv, global_dbias # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
...@@ -2273,7 +3010,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2273,7 +3010,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
mesh, PartitionSpec(get_all_mesh_axes(), None) mesh, PartitionSpec(get_all_mesh_axes(), None)
) )
arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding arg_shardings[5] = seed_sharding
# Ensure segment_pos gets same sharding as ID. # Ensure segment_pos gets same sharding as ID.
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
...@@ -2285,6 +3022,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2285,6 +3022,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -2336,6 +3074,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2336,6 +3074,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
seed, seed,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
...@@ -2345,7 +3084,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -2345,7 +3084,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids, kv_segment_ids,
q_segment_pos, q_segment_pos,
kv_segment_pos, kv_segment_pos,
config, config=config,
) )
if config.window_size != (-1, -1): if config.window_size != (-1, -1):
...@@ -2420,8 +3159,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2420,8 +3159,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
arg_shardings[-1] = arg_shardings[-3] arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4] arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings) arg_shardings = tuple(arg_shardings)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding # dq, dk, dv, dbias, dsoftmax_offset sharding = q, k, v, bias, softmax_offset sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) out_shardings = tuple(arg.sharding for arg in arg_infos[:5])
helper = _FusedAttnCPWithP2PHelper(mesh, config) helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported() helper.check_supported()
...@@ -2431,6 +3170,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2431,6 +3170,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
k, k,
v, v,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2478,11 +3218,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2478,11 +3218,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute(config): def compute(config):
dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( dq_per_step, dkv_per_step, _, dbias_per_step, _ = FusedAttnBwdPrimitive.impl(
q, q,
kv, kv,
_not_used, _not_used,
bias, bias,
_softmax_offset,
softmax_aux, softmax_aux,
rng_state, rng_state,
output, output,
...@@ -2536,7 +3277,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2536,7 +3277,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dkv) dk, dv = helper.unstack_kv(dkv)
return dq, dk, dv, global_dbias # Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(_softmax_offset)
return dq, dk, dv, global_dbias, dummy_dsoftmax_offset
return mesh, bwd_impl, out_shardings, arg_shardings return mesh, bwd_impl, out_shardings, arg_shardings
...@@ -2545,7 +3288,7 @@ register_primitive(FusedRingAttnStripedBwdPrimitive) ...@@ -2545,7 +3288,7 @@ register_primitive(FusedRingAttnStripedBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str): def _maybe_context_parallel_axis(cp_axis: str):
if not cp_axis: if not cp_axis and is_mesh_available():
gmr = global_mesh_resource() gmr = global_mesh_resource()
if gmr is not None: if gmr is not None:
cp_axis = gmr.cp_resource cp_axis = gmr.cp_resource
...@@ -2557,10 +3300,12 @@ def _maybe_context_parallel_axis(cp_axis: str): ...@@ -2557,10 +3300,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
def fused_attn_fwd( def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
softmax_offset: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor, sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray], seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, attn_mask_type: AttnMaskType,
softmax_type: AttnSoftmaxType,
qkv_layout: QKVLayout, qkv_layout: QKVLayout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
...@@ -2570,6 +3315,7 @@ def fused_attn_fwd( ...@@ -2570,6 +3315,7 @@ def fused_attn_fwd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
stripe_size: int | None = None,
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
Perform the forward pass with cuDNN fused attention implementations. Perform the forward pass with cuDNN fused attention implementations.
@@ -2585,6 +3331,7 @@ def fused_attn_fwd(
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
        q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
        kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
        q_seq_offsets (jnp.ndarray):
@@ -2594,6 +3341,7 @@ def fused_attn_fwd(
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        softmax_type (AttnSoftmaxType): Type of softmax.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
@@ -2606,6 +3354,7 @@ def fused_attn_fwd(
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
        stripe_size (int | None): The striping height to use for ReorderStrategy.Striped load balancing.
    Returns:
        (jnp.ndarray): The output tensor from the fused attention.
    """
@@ -2633,10 +3382,36 @@ def fused_attn_fwd(
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
    if softmax_offset is None:
        assert (
            softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX
        ), f"Softmax type {softmax_type} is not supported when softmax_offset is None"
        if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            num_heads = qkv[0].shape[-2]
            # Create tensor [1, h, 1, 1] filled with zeros (logit value = 0).
            # This adds exp(0 - x_max) = exp(-x_max) to the denominator,
            # which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
            # Shard by the heads dimension
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
        else:
            assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX
            softmax_offset = jnp.zeros(0, dtype=jnp.float32)
    else:
        assert softmax_offset.dtype == jnp.float32
        # Shard by the heads dimension unless using VANILLA_SOFTMAX
        if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
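    # Hedged illustration (assumed values, standalone math check, not part of
    # this function): with a zero offset logit, off-by-one softmax equals
    # appending a 0-logit column to vanilla softmax and dropping its mass:
    #   import jax
    #   import jax.numpy as jnp
    #   x = jnp.array([1.0, 2.0, 3.0])
    #   off_by_one = jnp.exp(x) / (jnp.sum(jnp.exp(x)) + 1.0)
    #   reference = jax.nn.softmax(jnp.append(x, 0.0))[:-1]
    #   assert jnp.allclose(off_by_one, reference)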
    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        softmax_type=softmax_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
@@ -2645,11 +3420,15 @@ def fused_attn_fwd(
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
        cp_striped_window_size=None,
        stripe_size=stripe_size,
    )
    primitive = None
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
            if qkv_layout.is_thd():
                primitive = FusedAttnCPStripedWithAllGatherFwdPrimitive.outer_primitive
            else:
                primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
        case CPStrategy.RING:
            # We must use stripe attention for THD-RING
@@ -2662,6 +3441,7 @@ def fused_attn_fwd(
    output, softmax_aux, rng_state = primitive.bind(
        *qkv_for_primitive,
        bias,
        softmax_offset,
        seed,
        *seq_desc_flatten,
        config=fused_config,
@@ -2673,6 +3453,7 @@ def fused_attn_fwd(
def fused_attn_bwd(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    softmax_offset: Optional[jnp.ndarray],
    softmax_aux: jnp.ndarray,
    rng_state: jnp.ndarray,
    output: jnp.ndarray,
@@ -2681,6 +3462,7 @@ def fused_attn_bwd(
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
@@ -2689,6 +3471,7 @@ def fused_attn_bwd(
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    stripe_size: int | None = None,
):
    """
    Perform the backward pass of the cuDNN fused attention implementations.
@@ -2702,6 +3485,7 @@ def fused_attn_bwd(
              query has a different shape (e.g., cross-attention).
            - `(query, key, value)`: For separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
        softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
        rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
        output (jnp.ndarray): The output tensor from the forward pass.
@@ -2714,6 +3498,7 @@ def fused_attn_bwd(
            The offsets in the sequence dim for the query, with shape [batch + 1,].
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        softmax_type (AttnSoftmaxType): Type of softmax.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
@@ -2726,6 +3511,7 @@ def fused_attn_bwd(
        context_parallel_causal_load_balanced (bool):
            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.
        stripe_size (int | None): The striping height to use for ReorderStrategy.Striped load balancing.
    Returns:
        Tuple[jnp.ndarray, ...], jnp.ndarray:
        - The first tuple contains the gradients with respect to the input `qkv` tensors in the
@@ -2755,6 +3541,28 @@ def fused_attn_bwd(
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv[0].dtype)
    if softmax_offset is None:
        assert (
            softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX
        ), f"Softmax type {softmax_type} is not supported when softmax_offset is None"
        if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            num_heads = qkv[0].shape[-2]
            # Create tensor [1, h, 1, 1] filled with zeros
            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
            # Shard by the heads dimension
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
        elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = jnp.zeros(0, dtype=jnp.float32)
        else:
            raise NotImplementedError(f"Unknown {softmax_type=}")
    else:
        softmax_offset = softmax_offset.astype(jnp.float32)
        # Shard by the heads dimension unless using VANILLA_SOFTMAX
        if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_offset = with_sharding_constraint_by_logical_axes(
                softmax_offset, (None, HEAD_AXES, None, None)
            )
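    # Hedged note: for AttnSoftmaxType.LEARNABLE_SOFTMAX the offset appears to
    # act as a trainable [1, h, 1, 1] parameter, which is why this backward
    # entry point now also returns softmax_offset_grad from the bind below.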
    # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
    # sm100+
    compute_capabilities = get_all_device_compute_capability()
@@ -2767,6 +3575,7 @@ def fused_attn_bwd(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        softmax_type=softmax_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
@@ -2775,11 +3584,15 @@ def fused_attn_bwd(
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
        cp_striped_window_size=None,
        stripe_size=stripe_size,
    )
    primitive = None
    match context_parallel_strategy:
        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
            if qkv_layout.is_thd():
                primitive = FusedAttnCPStripedWithAllGatherBwdPrimitive.outer_primitive
            else:
                primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
        case CPStrategy.RING:
            if qkv_layout.is_thd():
@@ -2788,9 +3601,10 @@ def fused_attn_bwd(
                primitive = FusedRingAttnBwdPrimitive.outer_primitive

    seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
    *qkv_grads, bias_grad, softmax_offset_grad = primitive.bind(
        *qkv_for_primitive,
        bias,
        softmax_offset,
        softmax_aux,
        rng_state,
        output,
@@ -2798,4 +3612,4 @@ def fused_attn_bwd(
        *seq_desc_flatten,
        config=fused_config,
    )
    return tuple(qkv_grads[: len(qkv)]), bias_grad, softmax_offset_grad
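# Hedged sketch of how callers typically pair these two entry points through
# jax.custom_vjp so autodiff routes gradients (including the new
# softmax_offset_grad) through the fused kernels; wrapper names here are
# hypothetical, not from this diff:
#   @jax.custom_vjp
#   def attn(qkv, bias, softmax_offset):
#       ...  # run fused_attn_fwd and return the attention output
#   def attn_fwd(qkv, bias, softmax_offset):
#       ...  # also stash softmax_aux, rng_state, and output as residuals
#       return out, residuals
#   def attn_bwd(residuals, dout):
#       qkv_grads, bias_grad, softmax_offset_grad = fused_attn_bwd(...)
#       return qkv_grads, bias_grad, softmax_offset_grad
#   attn.defvjp(attn_fwd, attn_bwd)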
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""JAX/TE base custom ops"""
@@ -176,6 +176,9 @@ _primitive_registry = {}
def register_primitive(cls, outer_only=False):
    """
    Register a JAX primitive and add it to the internal registry.
    Inner primitive - single device, no sharding awareness, eager-mode fallback.
    Outer primitive - multi-device, sharding-aware; partition() distributes the
    work and is used when there is a device mesh context.
    """
    _primitive_registry[cls.__name__] = cls
@@ -190,14 +193,17 @@ def register_primitive(cls, outer_only=False):
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    # Define the eager execution implementation (by invoking its MLIR lowering)
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
    cls.inner_primitive = inner_p
    # Create the outer primitive for distributed execution
    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    # Define the eager execution implementation
    outer_p.def_impl(cls.outer_impl)
    outer_p.def_abstract_eval(cls.outer_abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
...
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""JAX te modules"""
@@ -39,12 +39,12 @@ from ..quantize import (
    Quantizer,
    GroupedQuantizer,
    QuantizerSet,
    noop_quantizer_set,
    is_fp8_gemm_with_all_layouts_supported,
    apply_padding_to_scale_inv,
    get_quantize_config_with_recipe,
    get_global_quantize_recipe,
    QuantizeLayout,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
...