Unverified commit 9101a78f, authored by Michael Goldfarb and committed by GitHub
Browse files

[JAX] Context Parallel Attention with All-Gather (#1106)



Implementation of context parallel fused attention using all-gather.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent d2d4cf91
......@@ -4,6 +4,8 @@
import operator
import re
from functools import reduce
from itertools import product
import pytest
import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
......@@ -29,6 +31,28 @@ def generate_configs():
return configs
def generate_context_parallel_configs():
    """Build the pytest parameter list for (dp, cp, tp) device meshes.

    Every combination of the data-, context- and tensor-parallel sizes below is
    considered; a combination is kept only when ``is_devices_enough`` accepts the
    number of devices it needs. Each kept entry is a ``pytest.param`` holding the
    device count, the mesh shape ``(dp, cp, tp)``, the mesh axis names and a
    ``MeshResource`` that names those axes.
    """
    dp_options = (1, 2)
    cp_options = (1, 2, 4, 8)
    tp_options = (1, 2)
    axis_names = ("dp", "cp", "tp")

    params = []
    for mesh_shape in product(dp_options, cp_options, tp_options):
        dp, cp, tp = mesh_shape
        device_total = cp * tp * dp
        if not is_devices_enough(device_total):
            continue
        params.append(
            pytest.param(
                device_total,
                mesh_shape,
                axis_names,
                MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
                id=f"n{device_total}_dp{dp}_cp{cp}_tp{tp}",
            )
        )
    return params
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"
......
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
import pytest
from functools import partial
import jax
import jax.numpy as jnp
......@@ -10,8 +11,13 @@ import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
......@@ -19,6 +25,10 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
)
......@@ -263,7 +273,8 @@ class TestDistributedCrossAttn:
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
),
dtype=jnp.float32,
)
def ref_func(query, kv, mask):
......@@ -284,7 +295,7 @@ class TestDistributedCrossAttn:
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
return jnp.mean(output, dtype=jnp.float32)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
......@@ -310,3 +321,229 @@ class TestDistributedCrossAttn:
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)
class TestDistributedContexParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
batch, seqlen, heads, hidden = shape
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_causal_mask(batch, seqlen)
return q, k, v, mask
def qkv_to_layout(self, q, k, v, qkv_layout):
qkv_args = ()
match qkv_layout:
case QKVLayout.BSHD_BS2HD:
k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
kv = jnp.concatenate((k, v), axis=-3)
qkv_args = (q, kv)
case QKVLayout.BSHD_BSHD_BSHD:
qkv_args = (q, k, v)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")]
)
def test_contex_parallel_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)
_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups
# make sure the mesh evently divides cp and tp axis
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
def target_func(q, k, v, mask):
return jnp.mean(
fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
bias=None,
mask=mask,
seed=None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_causal_load_balanced=load_balanced,
),
).astype(dtype)
def ref_func(q, k, v, mask, kv_groups):
q = jnp.squeeze(q)
k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
output = dot_product_attention(
q,
k,
v,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)
# Single GPU (reference)
ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4])
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups)
# Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ps = PartitionSpec(
mesh_resource.dp_resource,
mesh_resource.cp_resource,
mesh_resource.tp_resource,
None,
)
qkv_sharding = NamedSharding(mesh, qkv_ps)
mask_ps = PartitionSpec(
mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
)
mask_sharding = NamedSharding(mesh, mask_ps)
reorder = partial(
reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
inverse_reorder = partial(
inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
if load_balanced:
q, k, v = jax.tree.map(reorder, (q, k, v))
q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
mask_ = jax.device_put(mask, device=mask_sharding)
target_func_jit = jax.jit(
jax.value_and_grad(target_func, argnums=[0, 1, 2]),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
)
target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)
if load_balanced:
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])
def _print_diffs(target, ref):
print("min: ", jnp.min(target), jnp.min(ref))
print("max: ", jnp.max(target), jnp.max(ref))
print("mean: ", jnp.mean(target), jnp.mean(ref))
print("median: ", jnp.median(target), jnp.median(ref))
print("std: ", jnp.std(target), jnp.std(ref))
print("var: ", jnp.var(target), jnp.var(ref))
print("max diff: ", jnp.max(jnp.abs(target - ref)))
has_diffs = False
try:
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
except AssertionError as e:
has_diffs = True
print(f"target_fwd v. ref_fwd")
_print_diffs(target_fwd, ref_fwd)
for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None:
# expect both none if one is
assert target_grads[i] is None and ref_grads[i] is None
else:
try:
assert_allclose(target_grads[i], ref_grads[i])
except AssertionError as e:
has_diffs = True
print(f"target_grads[{i}] v. ref_grads[{i}]")
_print_diffs(target_grads[i], ref_grads[i])
assert has_diffs == False, "has_diffs != False"
class TestReorderCausalLoadBalancing:
    """Round-trip test for the causal load-balancing reorder helpers."""

    @pytest.mark.parametrize("cp_size", [2, 4, 8])
    @pytest.mark.parametrize(
        "shape",
        [
            pytest.param([1, 16, 1, 1], id="1-16-1-1"),
            pytest.param([4, 32, 12, 32], id="4-32-12-32"),
            pytest.param([3, 32, 8, 64], id="3-32-8-64"),
        ],
    )
    @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
    def test(self, cp_size, shape, qkv_format):
        """Applying the reorder and then its inverse must reproduce the input exactly."""
        source = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
        if qkv_format == QKVFormat.SBHD:
            # The random tensor is generated with axis 1 as the sequence axis;
            # swap the first two axes for the sequence-first format.
            source = source.swapaxes(0, 1)

        expected = source.copy()

        static_positions = [1, 2]
        forward_fn = jax.jit(reorder_causal_load_balancing, static_argnums=static_positions)
        backward_fn = jax.jit(
            inverse_reorder_causal_load_balancing, static_argnums=static_positions
        )

        shuffled = forward_fn(source, cp_size, qkv_format)
        restored = backward_fn(shuffled, cp_size, qkv_format)

        assert jnp.array_equal(restored, expected)
......@@ -43,6 +43,8 @@ class AttnMaskType(Enum):
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
class QKVLayout(Enum):
......@@ -97,11 +99,21 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
return AttnMaskType.PADDING_MASK
case "causal":
return AttnMaskType.CAUSAL_MASK
case "causal_bottom_right" | "bottom_right_causal":
return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
case "padding_causal" | "causal_padding":
return AttnMaskType.PADDING_CAUSAL_MASK
case (
"padding_causal_bottom_right"
| "causal_padding_bottom_right"
| "bottom_right_causal_padding"
| "bottom_right_padding_causal"
):
return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
raise ValueError(
f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
" 'padding_causal', 'causal_padding', 'causal_bottom_right',"
" 'padding_causal_bottom_right'}"
)
......@@ -155,6 +167,75 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool):
match tensor_format:
case QKVFormat.SBHD:
seq_dim = 0
case QKVFormat.BSHD:
seq_dim = 1
case _:
raise ValueError(f"{tensor_format=} is not supported for causal load balancing.")
if cp_size == 1:
return tensor
if cp_size % 2 != 0:
raise ValueError(f"{cp_size=} must be a multiple of 2.")
# Need to ensure we have 2 pairs to swap for balancing between cp ranks
if tensor.shape[seq_dim] % (cp_size * 2) != 0:
raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")
# [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D]
# [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D]
ori_tensor_shape = tensor.shape
tensor = tensor.reshape(
(
*ori_tensor_shape[:seq_dim],
2 * cp_size,
ori_tensor_shape[seq_dim] // (2 * cp_size),
*ori_tensor_shape[seq_dim + 1 :],
)
)
parts = []
if not inverse:
for cp_rank in range(cp_size):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
parts.append(jnp.take(tensor, index, axis=seq_dim))
else:
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
base = 4 * cp_rank
index = jnp.array([base, base + 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
base = 2 * cp_size - 1 - 4 * cp_rank
index = jnp.array([base, base - 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D]
combined = jnp.stack(parts, axis=seq_dim)
return combined.reshape(ori_tensor_shape)
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Reorder ``tensor`` along its sequence axis for causal-attention load balancing.

    Delegates to ``_reorder_causal_load_balancing`` in the forward (non-inverse)
    direction and returns its result.
    """
    return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, inverse=False)
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
    """Undo `reorder_causal_load_balancing` on ``tensor``.

    Delegates to ``_reorder_causal_load_balancing`` in the inverse direction and
    returns its result.
    """
    return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, inverse=True)
def fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -166,6 +247,8 @@ def fused_attn(
scaling_factor: float,
dropout_probability: float,
is_training: bool,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
"""
Perform non-THD (non-packed) cuDNN fused attention.
......@@ -192,6 +275,9 @@ def fused_attn(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
"""
......@@ -213,7 +299,11 @@ def fused_attn(
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
# convert the mask to seqlens, mask doesn't support ragged offsets
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if attn_mask_type in [
AttnMaskType.NO_MASK,
AttnMaskType.CAUSAL_MASK,
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
]:
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
......@@ -242,6 +332,8 @@ def fused_attn(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=1,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
return output
......@@ -262,6 +354,8 @@ def fused_attn_thd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int = 1,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
"""
(Experimental) Perform THD (packed) cuDNN fused attention.
......@@ -300,6 +394,9 @@ def fused_attn_thd(
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -354,12 +451,14 @@ def fused_attn_thd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -375,6 +474,8 @@ def _fused_attn(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
output, _ = _fused_attn_fwd_rule(
qkv,
......@@ -391,6 +492,8 @@ def _fused_attn(
dropout_probability,
is_training,
max_segments_per_seq,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
return output
......@@ -410,6 +513,8 @@ def _fused_attn_fwd_rule(
dropout_probability,
is_training,
max_segments_per_seq,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
......@@ -426,6 +531,8 @@ def _fused_attn_fwd_rule(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
......@@ -451,6 +558,8 @@ def _fused_attn_bwd_rule(
dropout_probability,
is_training,
max_segments_per_seq,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
dz,
):
......@@ -483,6 +592,8 @@ def _fused_attn_bwd_rule(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
......
......@@ -9,8 +9,9 @@ import os
from typing import Optional, Tuple
import warnings
import jax
import jax.numpy as jnp
from jax import dtypes
from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
......@@ -34,7 +35,11 @@ from .misc import (
get_cudnn_version,
)
from ..sharding import (
global_mesh_resource,
lax_paral_op,
all_reduce_sum_along_dp_fsdp,
get_mesh_axis_size,
get_mesh_axis_rank,
get_all_mesh_axes,
num_of_devices,
)
......@@ -47,6 +52,38 @@ __all__ = [
]
# Registered with JAX as a dataclass pytree that has no data fields: every
# attribute is listed under meta_fields, so the whole config is metadata.
@partial(
    jax.tree_util.register_dataclass,
    data_fields=[],
    meta_fields=[
        "attn_bias_type",
        "attn_mask_type",
        "qkv_layout",
        "scaling_factor",
        "dropout_probability",
        "is_training",
        "max_segments_per_seq",
        "context_parallel_load_balanced",
        "cp_axis",
    ],
)
@dataclass(frozen=True)
class _FusedAttnConfig:
    """
    Passes static configuration properties of fused attention.

    The instance is frozen (immutable) and bundles the attention settings into a
    single object so they can be handed around as one ``config`` value.
    """

    # Bias variant, as an NVTE enum value.
    attn_bias_type: NVTE_Bias_Type
    # Mask variant, as an NVTE enum value.
    attn_mask_type: NVTE_Mask_Type
    # Memory layout of the q/k/v inputs, as an NVTE enum value.
    qkv_layout: NVTE_QKV_Layout
    # Scaling factor for the attention scores.
    scaling_factor: float
    # Dropout probability to apply during attention.
    dropout_probability: float
    # Whether the model is in training mode.
    is_training: bool
    # Maximum number of segments inside a sequence.
    max_segments_per_seq: int
    # Presumably: sequences are already ordered for causal-mask load balancing
    # under context parallelism — TODO confirm against the caller that builds this.
    context_parallel_load_balanced: bool
    # Presumably the name of the context-parallel mesh axis — TODO confirm.
    cp_axis: str
@dataclass(frozen=True)
class FusedAttnHelper:
"""
......@@ -178,7 +215,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (9, 10, 11, 12, 13, 14, 15)
impl_static_args = (9,)
inner_primitive = None
outer_primitive = None
......@@ -194,13 +231,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
_k_seq_offsets,
seed_aval,
*,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
config: _FusedAttnConfig,
):
"""
Fused attention fwd abstract
......@@ -213,7 +244,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
......@@ -223,10 +254,10 @@ class FusedAttnFwdPrimitive(BasePrimitive):
backend = FusedAttnHelper(
q_dtype,
k_dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_probability,
config.qkv_layout,
config.attn_bias_type,
config.attn_mask_type,
config.dropout_probability,
attn_heads,
num_gqa_groups,
q_max_seqlen,
......@@ -238,7 +269,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq)
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
......@@ -252,7 +283,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -270,14 +301,14 @@ class FusedAttnFwdPrimitive(BasePrimitive):
num_gqa_groups,
bias_heads,
head_dim,
scaling_factor,
dropout_probability,
attn_bias_type,
attn_mask_type,
qkv_layout,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
is_training,
max_segments_per_seq,
config.is_training,
config.max_segments_per_seq,
)
wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -308,28 +339,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_seq_offsets,
seed,
*,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
config: _FusedAttnConfig,
):
"""
Fused attention fwd lowering rules
"""
operands = [
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
......@@ -340,12 +355,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -362,16 +377,16 @@ class FusedAttnFwdPrimitive(BasePrimitive):
num_gqa_groups,
bias_heads,
head_dim,
max_segments_per_seq,
config.max_segments_per_seq,
wkspace_aval.size,
scaling_factor,
dropout_probability,
attn_bias_type,
attn_mask_type,
qkv_layout,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training,
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
)
......@@ -390,17 +405,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
q_seq_offsets,
k_seq_offsets,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
config: _FusedAttnConfig,
):
assert FusedAttnFwdPrimitive.inner_primitive is not None
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
......@@ -418,7 +427,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
)
return offsets_2d
match qkv_layout:
match config.qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
......@@ -472,66 +481,27 @@ class FusedAttnFwdPrimitive(BasePrimitive):
q_seq_offsets,
k_seq_offsets,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
config=config,
)
return output, softmax_aux, rng_state
@staticmethod
def batcher(
batched_args,
batch_dims,
*,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return (
FusedAttnFwdPrimitive.outer_primitive.bind(
*batched_args,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
),
FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, max_segments_per_seq, result_infos
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
match qkv_layout:
match config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
# q_spec = (...batch, q_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
......@@ -543,33 +513,22 @@ class FusedAttnFwdPrimitive(BasePrimitive):
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
raise ValueError(f"Unsupported {config.qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
@staticmethod
def partition(
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
):
def partition(config, mesh, arg_infos, result_infos):
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
......@@ -577,16 +536,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
)
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(
FusedAttnFwdPrimitive.impl,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings
......@@ -600,7 +550,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (12, 13, 14, 15, 16, 17, 18)
impl_static_args = (12,)
inner_primitive = None
outer_primitive = None
......@@ -619,13 +569,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_q_seq_offsets,
_k_seq_offsets,
*,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
config,
):
"""
Fused attention bwd abstract
......@@ -641,10 +585,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -662,15 +606,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
num_gqa_groups,
bias_heads,
head_dim,
scaling_factor,
dropout_probability,
attn_bias_type,
attn_mask_type,
qkv_layout,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
is_training,
config.is_training,
deterministic,
max_segments_per_seq,
config.max_segments_per_seq,
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
......@@ -707,13 +651,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_seq_offsets,
k_seq_offsets,
*,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
config,
):
"""
Fused attention bwd lowering rules
......@@ -743,12 +681,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -765,16 +703,16 @@ class FusedAttnBwdPrimitive(BasePrimitive):
num_gqa_groups,
bias_heads,
head_dim,
max_segments_per_seq,
config.max_segments_per_seq,
wkspace_aval.size,
scaling_factor,
dropout_probability,
attn_bias_type,
attn_mask_type,
qkv_layout,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training,
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
)
......@@ -796,17 +734,11 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
config,
):
assert FusedAttnBwdPrimitive.inner_primitive is not None
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
......@@ -825,7 +757,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
)
return offsets_2d
match qkv_layout:
match config.qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
......@@ -882,63 +814,25 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
config=config,
)
return dq, dk, dv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, config):
    """Batching rule for the fused attention backward primitive.

    Re-binds the outer primitive on the already batched operands, forwarding the
    single ``config`` object unchanged.

    Args:
        batched_args: Operands of the primitive, forwarded positionally to ``bind``.
        batch_dims: Batch dimension of each operand; validated by
            ``check_valid_batch_dims``. Only the first three entries (q, k, v)
            are read here.
        config: Attention configuration passed through to the outer primitive.

    Returns:
        A ``(outputs, out_bdims)`` pair where ``out_bdims`` is
        ``(q_bdim, k_bdim, v_bdim, q_bdim)``.
    """
    check_valid_batch_dims(batch_dims)
    assert FusedAttnBwdPrimitive.outer_primitive is not None
    q_bdim, k_bdim, v_bdim, *_ = batch_dims

    # The fourth output (dbias) reuses the batch dimension of q.
    out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
    return (
        FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
        out_bdims,
    )
@staticmethod
def infer_sharding_from_operands(
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
):
del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq
del dropout_probability, is_training, result_infos
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del config, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
......@@ -950,18 +844,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
@staticmethod
def partition(
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
):
def partition(config, mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
......@@ -1001,16 +884,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
config=config,
)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias
......@@ -1020,6 +897,378 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive)
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Bundles the mesh and attention config used by the all-gather context parallel path.

    The methods below validate the configuration, move K/V (and their gradients)
    across the context parallel axis, and slice/pad tensors along the sequence
    dimension (axis 1).
    """

    # Device mesh handed to the collective wrappers.
    mesh: jax.sharding.Mesh
    # Attention configuration; supplies the layout, mask type and cp axis name.
    config: _FusedAttnConfig

    def check_supported(self):
        """Asserts that the configuration can run with context parallelism.

        Raises:
            AssertionError: if the layout, bias type, mask type, segment count or
                dropout probability is outside of what this path handles.
        """
        cfg = self.config
        header = "Context parallel fused attention"

        allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
        assert cfg.qkv_layout in allowed_layouts, (
            f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
            f" {cfg.qkv_layout}"
        )

        assert (
            cfg.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
        ), f"{header} does not support bias got: {cfg.attn_bias_type}"

        allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
        assert cfg.attn_mask_type in allowed_masks, (
            f"{header} only supports masking types: "
            f" {','.join([str(x) for x in allowed_masks])} got: {cfg.attn_mask_type}"
        )

        assert cfg.max_segments_per_seq == 1, (
            f"{header} only supports max_segments_per_seq == 1 got:"
            f" {cfg.max_segments_per_seq}"
        )

        assert cfg.dropout_probability == 0.0, f"{header} does not support dropout"

    def get_adjusted_mask(self):
        """Returns the mask type to use under context parallelism.

        A causal mask is mapped to its bottom-right aligned variant; every other
        mask type is returned unchanged.
        """
        mask_type = self.config.attn_mask_type
        is_causal = mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK
        return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK if is_causal else mask_type

    def all_gather_kv(self, k, v):
        """All-gathers k and v along axis 1 over the context parallel axis.

        For the BSHD_BS2HD layout only ``k`` is gathered and ``v`` is returned as
        given; for BSHD_BSHD_BSHD both are gathered. Any other layout returns
        the inputs untouched.
        """

        def gather_seq(x):
            return lax_paral_op(
                x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
            )

        layout = self.config.qkv_layout
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
            return gather_seq(k), v
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
            return gather_seq(k), gather_seq(v)
        return k, v  # unsupported layouts pass through

    def reduce_scatter_dkv(self, dk, dv):
        """Reduce-scatters dk and dv along axis 1 over the context parallel axis.

        Mirrors ``all_gather_kv``: only ``dk`` is scattered for BSHD_BS2HD, both
        gradients for BSHD_BSHD_BSHD, and other layouts pass through.
        """

        def scatter_seq(x):
            return lax_paral_op(
                x,
                lax.psum_scatter,
                self.config.cp_axis,
                mesh=self.mesh,
                scatter_dimension=1,
                tiled=True,
            )

        layout = self.config.qkv_layout
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
            return scatter_seq(dk), dv
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
            return scatter_seq(dk), scatter_seq(dv)
        return dk, dv  # unsupported layouts pass through

    def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
        """Returns the two KV sequence lengths (one per sub rank) used by ``cp_rank``.

        Example: CP=4, MaxLen = 1024, Unbalanced
           cp_rank 0: [128, 256]
           cp_rank 1: [384, 512]
           cp_rank 2: [640, 768]
           cp_rank 3: [896, 1024]

        Example: CP=4, MaxLen = 1024, Balanced
           cp_rank 0: [128, 1024]
           cp_rank 1: [256, 896]
           cp_rank 2: [384, 768]
           cp_rank 3: [512, 640]
        """
        if self.config.context_parallel_load_balanced:
            # First chunk counts up from the front, second counts down from the end.
            first_len = (cp_rank + 1) * kv_seqlen_per_subrank
            second_len = kv_max_seqlen - cp_rank * kv_seqlen_per_subrank
        else:
            # Two consecutive chunks per rank.
            first_len = (cp_rank * 2 + 1) * kv_seqlen_per_subrank
            second_len = (cp_rank * 2 + 2) * kv_seqlen_per_subrank
        return [first_len, second_len]

    def slice_kv(self, k, v, slice_seq_len):
        """Keeps the first ``slice_seq_len`` positions of k (and v) along axis 1.

        For BSHD_BS2HD only ``k`` is sliced; for BSHD_BSHD_BSHD both are. Other
        layouts return the inputs unchanged.
        """

        def take_prefix(x):
            return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)

        layout = self.config.qkv_layout
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
            return take_prefix(k), v
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
            return take_prefix(k), take_prefix(v)
        return k, v  # unsupported layouts pass through

    def pad_kv(self, dk, dv, pad_seq_len):
        """Appends ``pad_seq_len`` zeros at the end of axis 1 of dk (and dv).

        A 5-entry pad spec is used for BSHD_BS2HD (only ``dk`` is padded) and a
        4-entry spec for BSHD_BSHD_BSHD (both padded). Other layouts return the
        inputs unchanged.
        """

        def zero_pad(x, pad_spec):
            return jnp.pad(x, pad_spec, "constant", constant_values=0.0)

        seq_pad = [0, pad_seq_len]
        layout = self.config.qkv_layout
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
            pad_spec = [[0, 0], seq_pad, [0, 0], [0, 0], [0, 0]]
            return zero_pad(dk, pad_spec), dv
        if layout == NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
            pad_spec = [[0, 0], seq_pad, [0, 0], [0, 0]]
            return zero_pad(dk, pad_spec), zero_pad(dv, pad_spec)
        return dk, dv  # unsupported layouts pass through
class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Attention Forward with Context Parallelism Primitive

    This context parallel implementation uses all-gather to collect KV inputs from context
    parallel ranks.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the per-shard forward implementation and its shardings.

        Falls back to ``FusedAttnFwdPrimitive.partition`` when the context parallel
        axis has size 1 (or is disabled).

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``.
        """
        # Call base implementation for non-context parallel mesh to avoid unnecessary work.
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        if not is_context_parallel:
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        # Every operand keeps its incoming sharding except the last one (the seed).
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)

            # cuDNN does not support right-aligned masking with dynamic sequence length padding.
            # Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
            # to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
            # meeting the expectation of the SPMD model.
            # TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
            # mask/sequence length tensor to avoid this unrolled loop.
            def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
                # NOTE(review): k is the tiled all-gather over cp_size ranks, so this check
                # looks trivially true, while the chunking above divides by 2 * cp_size.
                # Confirm whether divisibility by 2 * cp_size should be enforced instead.
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                # Each rank processes its local q as two halves along the sequence axis.
                q_split = jnp.split(q, 2, axis=1)

                kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )

                results = []
                for sub_idx in range(2):
                    if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK:
                        k_unmasked, v_unmasked = k, v  # full kv used for unmasked
                    else:
                        k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])

                    # Floor division keeps the sequence lengths in their integer dtype and
                    # matches the arithmetic of the backward pass; true division would
                    # hand floating point lengths to the attention kernel wrapper.
                    q_seqlen_for_step = q_seqlen // (cp_size * 2)
                    num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
                    kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks

                    # NOTE(review): helper.get_adjusted_mask() exists but the unmodified
                    # config (and hence mask type) is forwarded here - confirm intent.
                    output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
                        q_split[sub_idx],
                        k_unmasked,
                        v_unmasked,
                        bias,
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        seed,
                        config=config,
                    )
                    results.append((output, softmax_aux, rng_state))

                # Stitch the two halves back together: outputs along axis 1,
                # softmax statistics along axis 2.
                output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
                softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
                rng_state = results[1][2]  # Use the final RNG state

                return output, softmax_aux, rng_state

            k_ag, v_ag = helper.all_gather_kv(k, v)

            # One statically specialised branch per cp rank; the runtime rank picks one.
            functions = [
                partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed)
                for idx in range(cp_size)
            ]

            return lax.switch(cp_rank, functions)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)
class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Attention Backward with Context Parallelism Primitive.

    K and V are all-gathered over the context parallel ranks before the backward
    kernel runs; the resulting dK/dV are reduce-scattered back to each rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Builds the per-shard backward implementation and its shardings.

        Returns:
            ``(mesh, impl, out_shardings, arg_shardings)``; delegates to
            ``FusedAttnBwdPrimitive.partition`` when the context parallel axis
            has size 1 (or is disabled).
        """
        # Without a context parallel axis larger than one the base rule is sufficient.
        if not get_mesh_axis_size(config.cp_axis, mesh) > 1:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        # Ensure we can support this configuration with context parallelism.
        helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
        helper.check_supported()

        def _sharding_like(arg_info):
            return NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_info)))

        # dq, dk, dv and dbias are sharded like q, k, v and bias respectively.
        out_shardings = tuple(_sharding_like(arg_infos[pos]) for pos in range(4))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
        ):
            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
            num_subranks = cp_size * 2
            uses_full_kv = config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK

            # See comment in FusedAttnCPWithAllGatherFwdPrimitive.partition for why one
            # specialised function per rank is defined.
            def _cross_attn_bwd(
                idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen
            ):
                kv_max_seqlen = k.shape[1]
                kv_seqlen_per_subrank = kv_max_seqlen // num_subranks
                assert kv_max_seqlen % cp_size == 0, "sequence length must evenly divide cp size"

                # Local tensors are processed as two halves: axis 1 for q/output/doutput,
                # axis 2 for the softmax statistics.
                q_halves = jnp.split(q, 2, axis=1)
                output_halves = jnp.split(output, 2, axis=1)
                doutput_halves = jnp.split(doutput, 2, axis=1)
                softmax_aux_halves = jnp.split(softmax_aux, 2, axis=2)

                step_kv_lens = helper.kv_seqlens_for_rank(
                    idx, kv_max_seqlen, kv_seqlen_per_subrank
                )
                q_seqlen_for_step = q_seqlen // num_subranks

                dq_parts, dk_parts, dv_parts = [], [], []
                dbias_step = None
                for sub_idx in range(2):
                    step_kv_len = step_kv_lens[sub_idx]
                    if uses_full_kv:
                        k_step, v_step = k, v  # full kv used for unmasked
                    else:
                        k_step, v_step = helper.slice_kv(k, v, step_kv_len)

                    num_kv_chunks = kv_max_seqlen // step_kv_len
                    kv_seqlen_for_step = (kv_seqlen // num_subranks) * num_kv_chunks

                    dq_step, dk_step, dv_step, dbias_step = FusedAttnBwdPrimitive.impl(
                        q_halves[sub_idx],
                        k_step,
                        v_step,
                        bias,
                        softmax_aux_halves[sub_idx],
                        rng_state,
                        output_halves[sub_idx],
                        doutput_halves[sub_idx],
                        q_seqlen_for_step,
                        kv_seqlen_for_step,
                        q_seq_offsets,
                        k_seq_offsets,
                        config=config,
                    )

                    # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
                    if not uses_full_kv:
                        dk_step, dv_step = helper.pad_kv(
                            dk_step, dv_step, kv_max_seqlen - step_kv_len
                        )

                    dq_parts.append(dq_step)
                    dk_parts.append(dk_step)
                    dv_parts.append(dv_step)

                # dq halves are stitched along the sequence axis; the padded dk/dv
                # contributions of both halves are summed. dbias is taken from the
                # second half only.
                dq_local = jnp.concatenate((dq_parts[0], dq_parts[1]), axis=1)
                dk_local_pad = dk_parts[0] + dk_parts[1]
                dv_local_pad = dv_parts[0] + dv_parts[1]
                return dq_local, dk_local_pad, dv_local_pad, dbias_step

            k_ag, v_ag = helper.all_gather_kv(k, v)

            # One statically specialised branch per cp rank; the runtime rank picks one.
            branches = [
                partial(
                    _cross_attn_bwd,
                    idx,
                    q,
                    k_ag,
                    v_ag,
                    bias,
                    softmax_aux,
                    rng_state,
                    output,
                    doutput,
                    q_seqlen,
                    kv_seqlen,
                )
                for idx in range(cp_size)
            ]

            dq, dk_local, dv_local, dbias = lax.switch(cp_rank, branches)
            dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)

            return dq, dk, dv, dbias

        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str):
    """Resolves the context parallel axis name to use.

    A non-empty ``cp_axis`` is returned as given. Otherwise the ``cp_resource`` of
    the global mesh resource is returned, or ``""`` when no global mesh resource
    is available.
    """
    if cp_axis:
        return cp_axis

    mesh_resource = global_mesh_resource()
    return "" if mesh_resource is None else mesh_resource.cp_resource
def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -1035,6 +1284,8 @@ def fused_attn_fwd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
) -> jnp.ndarray:
"""
Perform the forward pass of with cuDNN fused attention implementations.
......@@ -1063,6 +1314,9 @@ def fused_attn_fwd(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
"""
......@@ -1094,14 +1348,7 @@ def fused_attn_fwd(
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(
*qkv_for_primitive,
bias,
q_seqlen,
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
seed,
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
......@@ -1109,6 +1356,19 @@ def fused_attn_fwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind(
*qkv_for_primitive,
bias,
q_seqlen,
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
seed,
config=fused_config,
)
......@@ -1130,6 +1390,8 @@ def fused_attn_bwd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
"""
Perform the backward pass of the cuDNN fused attention implementations.
......@@ -1159,7 +1421,9 @@ def fused_attn_bwd(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
Tuple[jnp.ndarray, ...], jnp.ndarray:
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
......@@ -1194,7 +1458,19 @@ def fused_attn_bwd(
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
*qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind(
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
*qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind(
*qkv_for_primitive,
bias,
softmax_aux,
......@@ -1205,12 +1481,6 @@ def fused_attn_bwd(
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
config=fused_config,
)
return tuple(qkv_grads[: len(qkv)]), bias_grad
......@@ -100,7 +100,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
......
......@@ -20,6 +20,7 @@ _PXLA_THREAD_RESOURCES = pxla.thread_resources
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
......@@ -65,6 +66,7 @@ def get_sharding_map_logic_axis_to_mesh_axis():
BATCH_AXES: batch_dim_rule,
SEQLEN_AXES: None,
SEQLEN_TP_AXES: gsr.tp_resource,
SEQLEN_CP_AXES: gsr.cp_resource,
HEAD_AXES: gsr.tp_resource,
HIDDEN_AXES: None,
HIDDEN_TP_AXES: gsr.tp_resource,
......@@ -131,13 +133,15 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))
def lax_paral_op(
    x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
    """
    A wrapper function to invoke lax.p* operations, like psum.

    Args:
        x: The array handed to ``ops``.
        ops: The collective to call as ``ops(x, axis_name, **kwargs)``.
        mesh_resource: Name of the mesh axis to run the collective over. When it is
            None no collective is performed.
        mesh: The mesh used to resolve ``mesh_resource``.
        **kwargs: Extra keyword arguments forwarded verbatim to ``ops``
            (e.g. ``axis=1, tiled=True`` for an all-gather).

    Returns:
        The result of ``ops``, or ``x`` itself when ``mesh_resource`` is None.
    """
    if mesh_resource is not None:
        _, resource = _get_mesh_info(mesh_resource, mesh)
        return ops(x, resource, **kwargs)
    return x
......@@ -148,6 +152,33 @@ def num_of_devices():
return len(jax.devices())
def get_mesh_axis_size(axis, mesh=None):
    """
    Get the axis size of the given mesh.
    If the mesh is None, it would be replaced
    by the global mesh.

    A disabled axis - ``None`` or the empty string that some callers use as a
    "no axis" placeholder - always has size 1.

    Raises:
        AssertionError: if ``axis`` is set but is not an axis of the mesh.
    """
    # Resolve the disabled-axis case first so that no mesh lookup is needed for it.
    if not axis:
        return 1

    if mesh is None:
        mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh

    assert axis in mesh.shape, f"{axis} is not a axis of the given mesh {mesh.shape}"
    return mesh.shape[axis]
def get_mesh_axis_rank(axis: str, mesh=None):
    """
    Gets the local axis rank of the `axis` of the array.
    If the mesh is None the rank is 0.

    A disabled axis (``None`` or an empty string) also yields rank 0, which is
    consistent with ``get_mesh_axis_size`` reporting a size of 1 for it.
    """
    if mesh is None or not axis:
        return 0
    _, axis_name = _get_mesh_info(axis, mesh)
    return jax.lax.axis_index(axis_name)
@dataclass
class MeshResource:
"""
......@@ -168,12 +199,16 @@ class MeshResource:
pp_resource : str, default = None
The axis name in Mesh used to split model layers. along.
If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None
The axis name in Mesh used to split sequence (context) dimensions along
in the attention. If it is None, then context parallelism is disabled.
"""
dp_resource: str = None
tp_resource: str = None
fsdp_resource: str = None
pp_resource: str = None
cp_resource: str = None
_GLOBAL_MESH_RESOURCE = MeshResource()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment