Unverified commit bfddb483 authored by Ming-Xu Huang, committed by GitHub

[JAX] Support Ring Attention (Context Parallelism) (#1059)



* Implement ring attention primitive for JAX.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2643ba1d
@@ -5,4 +5,11 @@
set -xe
: ${TE_PATH:=/opt/transformerengine}
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
# Skip ring attention tests since they require specific environment variables
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
# Test ring attention with and without scan loop
NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
@@ -35,7 +35,9 @@ from transformer_engine.jax.attention import (
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
# We will use the golden reference model from our non-distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
@@ -333,6 +335,36 @@ class TestDistributedCrossAttn:
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
@@ -370,37 +402,7 @@ class TestDistributedContextParallelSelfAttn:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args
-@pytest.mark.parametrize(
-"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
-)
-@pytest.mark.parametrize(
-"data_shape",
-[
-pytest.param([2, 512, 12, 128], id="2-512-12-128"),
-pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
-],
-)
-@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
-@pytest.mark.parametrize(
-"attn_mask_type",
-[
-pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
-pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
-],
-)
-@pytest.mark.parametrize("dtype", [jnp.bfloat16])
-@pytest.mark.parametrize(
-"qkv_layout",
-[
-pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
-pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
-],
-)
-@pytest.mark.parametrize(
-"load_balanced",
-[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
-)
-def test_contex_parallel_self_attn(
def impl_test_contex_parallel_attn(
self,
device_count,
mesh_shape,
@@ -412,6 +414,7 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
load_balanced,
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
@@ -469,6 +472,7 @@
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
@@ -574,6 +578,60 @@
assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
def test_contex_parallel_allgather_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
)
def test_context_parallel_ring_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.RING,
)
class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@@ -7,6 +7,7 @@ from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random
import jax
import jax.numpy as jnp
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from functools import partial
import os
from transformer_engine.jax.cpp_extensions.misc import get_xla_flag
@pytest.fixture(autouse=True, scope="function")
def preserve_xla_flags():
"""Ensures the XLA_FLAGS environment variable is restored after any test in this file runs."""
old_flags = os.getenv("XLA_FLAGS")
yield
if old_flags is not None:
os.environ["XLA_FLAGS"] = old_flags
else:
# If XLA_FLAGS was originally unset, drop any value a test may have set.
os.environ.pop("XLA_FLAGS", None)
def test_get_xla_flag(request):
os.environ["XLA_FLAGS"] = ""
assert get_xla_flag("") is None
assert get_xla_flag("--foo") is None
assert get_xla_flag("--bar=1") is None
os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
assert get_xla_flag("--foo") == True
assert get_xla_flag("--bar") == "1"
assert get_xla_flag("--bar", cast=int) == 1
assert get_xla_flag("--bar", cast=bool) == True
assert get_xla_flag("--baz") == "biz"
with pytest.raises(ValueError):
# cast will fail
assert get_xla_flag("--baz", cast=int)
assert get_xla_flag("--xla") is None
os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
assert get_xla_flag("--xla_abc") == True
assert get_xla_flag("--xla_abb") == True
@@ -79,6 +79,19 @@ class QKVFormat(Enum):
THD = NVTE_QKV_Format.NVTE_THD
class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention.
DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
ALL_GATHER: All-gather/reduce scatter implementation.
RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
"""
DEFAULT = 0
ALL_GATHER = 1
RING = 2
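# Usage sketch (illustrative, not part of this diff; unrelated fused_attn
# arguments elided): the strategy is threaded through the public API, e.g.
#   fused_attn(..., context_parallel_strategy=CPStrategy.RING,
#              context_parallel_causal_load_balanced=True,
#              context_parallel_axis="cp")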
def get_qkv_format(qkv_layout):
"""
Get qkv_format from qkv_layout
@@ -260,6 +273,7 @@ def fused_attn(
dropout_probability: float,
is_training: bool,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
@@ -347,6 +361,7 @@
is_training=is_training,
max_segments_per_seq=1,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -370,6 +385,7 @@ def fused_attn_thd(
is_training: bool,
max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
@@ -470,6 +486,7 @@ def fused_attn_thd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -477,7 +494,7 @@ def fused_attn_thd(
return output
-@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
@@ -494,6 +511,7 @@ def _fused_attn(
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]],
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
@@ -513,6 +531,7 @@ def _fused_attn(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
@@ -554,6 +574,7 @@ def _fused_attn_fwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -17,6 +17,8 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine.jax.attention import CPStrategy
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
NVTE_Bias_Type,
@@ -35,6 +37,7 @@ from .misc import (
get_padded_spec,
get_cudnn_version,
is_ffi_enabled,
get_xla_flag,
)
from ..sharding import (
global_mesh_resource,
@@ -1032,7 +1035,7 @@ class _FusedAttnCPWithAllGatherHelper:
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
@@ -1042,7 +1045,7 @@ class _FusedAttnCPWithAllGatherHelper:
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
)
if self.config.max_segments_per_seq != 1:
@@ -1411,6 +1414,503 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
"""Helper class to assist with running the P2P ring strategy for CP attention."""
mesh: jax.sharding.Mesh
config: _FusedAttnConfig
@staticmethod
def use_scanloop():
"""Returns true if the implementation will use a scan loop for iteration."""
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
# nvbug(4675071): Disable the HLO verifier for channel ID checks.
# A WAR was added to XLA: https://github.com/openxla/xla/pull/16779
def truthy(val):
return val.lower() in ["1", "true"]
x = use_scan and get_xla_flag(
"--xla_experimental_ignore_channel_id", default=False, cast=truthy
)
return x
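# Example environment (mirroring the CI script above) to opt in to the scan
# loop; both settings must be in place before JAX/XLA initialize:
#   NVTE_FUSED_RING_ATTENTION_USE_SCAN=1
#   XLA_FLAGS="--xla_experimental_ignore_channel_id"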
def check_supported(self):
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused ring attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
)
if self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout")
# We want to encourage use of the scan loop to minimize unrolling and to get
# more predictable scheduling from XLA. The unrolled flavor remains supported
# but is not the preferred implementation.
if not self.use_scanloop():
warnings.warn(
"Scan loop is disabled for fused ring attention. To enable it, set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and"
" add --xla_experimental_ignore_channel_id=true to XLA_FLAGS."
)
def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
)
def stack_kv(self, k, v):
"""Stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=k.dtype)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
return k
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return jnp.stack([k, v], axis=2)
return _not_used
def unstack_kv(self, kv):
"""Un-stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=kv.dtype)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
return kv, _not_used
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
return jnp.unstack(kv, axis=2)
return _not_used, _not_used # fall through
def permute_kv(self, kv, cp_perm):
"""Permutes kv around the ring as described by cp_perm."""
return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)
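# lax.ppermute lowers to point-to-point sends/receives between the paired
# ranks, which is what allows the transfer to overlap with attention compute.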
def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step):
"""Apply soft max correction after an attention step."""
max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step)
min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step)
new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale))
return new_softmax_aux
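# In closed form this is the streaming log-sum-exp merge:
#   lse(a, b) = max(a, b) + log(1 + exp(min(a, b) - max(a, b))) = log(exp(a) + exp(b))
# evaluated via the max so the exponential never overflows.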
def adjust_seqlen(self, seqlen, max_seqlen, idx):
"""Adjust the sequence length per step."""
seqlen_of_curr_step = seqlen - max_seqlen * idx
seqlen_of_curr_step = jnp.where(seqlen_of_curr_step < 0, 0, seqlen_of_curr_step)
seqlen_per_step = jnp.where(
seqlen_of_curr_step < max_seqlen, seqlen_of_curr_step, max_seqlen
)
return seqlen_per_step
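# Worked example (values are illustrative only): with an actual sequence length
# of 700 and per-step chunks of max_seqlen=512, this returns 512 for idx=0,
# 188 for idx=1, and 0 for every later step.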
class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Ring Attention Forward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def ring_attn_fwd_impl(
q,
k,
v,
bias,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
):
_not_used = jnp.zeros(0, dtype=v.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
batch, q_max_seqlen, head, _ = q.shape
kv_max_seqlen = k.shape[1]
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
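# For cp_size=4 the permutation is [(0, 1), (1, 2), (2, 3), (3, 0)]: each rank
# hands its current KV block to the next rank, closing the ring.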
output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype)
softmax_aux_per_steps = jnp.zeros(
(cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32
)
softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)
# The RNG state shape should be the per-device (sharded) shape. It is unused for
# ring attention since we do not currently support dropout.
rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)
def scan_kv_block(idx, carry):
kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry
# Send the KV block to the next rank now so that communication overlaps with compute.
kv_next = helper.permute_kv(kv, cp_perm)
def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q,
kv,
_not_used,
bias,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
seed,
helper.get_step_config(attn_mask_type),
)
return output_per_step, softmax_aux_per_step
causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK)
no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK)
def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q,
kv_part,
_not_used,
bias,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
seed,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
return output_per_step, softmax_aux_per_step
def half_q_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q_part,
kv,
_not_used,
bias,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
seed,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
softmax_aux_per_step = jnp.concat(
[
jnp.full_like(softmax_aux_per_step, -jnp.inf),
softmax_aux_per_step,
],
axis=2,
)
return output_per_step, softmax_aux_per_step
def skip_compute():
output_per_step = jnp.zeros_like(q)
softmax_aux_per_step = jnp.full(
(batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
)
return output_per_step, softmax_aux_per_step
if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
# This is for nested jax.lax.cond
def jax_cond_wrap():
if config.context_parallel_load_balanced:
return lax.cond(
(idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
)
return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)
output_per_step, softmax_aux_per_step = lax.cond(
idx == 0, causal_mask_compute, jax_cond_wrap
)
else:
output_per_step, softmax_aux_per_step = no_mask_compute()
softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step)
output_per_steps = output_per_steps.at[idx].set(output_per_step)
softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step)
return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps)
carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry
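# Combine the per-step partial outputs with the standard online-softmax
# correction, out = sum_i out_i * exp(lse_i - lse_total), where lse_total is
# the merged statistic accumulated in softmax_aux above.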
output = jnp.zeros(q.shape).astype(jnp.float32)
for idx in range(cp_size):
output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
softmax_aux_per_steps[idx] - softmax_aux
).transpose(0, 2, 1, 3)
output = output.astype(q.dtype)
return output, softmax_aux, rng_state
return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnFwdPrimitive)
class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Ring Attention Backward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
def ring_attn_bwd_impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
):
_not_used = jnp.zeros(0, dtype=output.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
q_max_seqlen = q.shape[1]
kv_max_seqlen = k.shape[1]
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
dq = jnp.zeros_like(q)
dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
dbias = jnp.zeros_like(bias)
def scan_kv_block(idx, carry):
kv, dq, dk_dv, dbias = carry
# Start the communication that feeds the next iteration.
# We further combine the tensors to improve overlap.
kv_dk_dv = jnp.stack([kv, dk_dv])
kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)
def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv,
_not_used,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
config=helper.get_step_config(attn_mask_type),
)
return dq_per_step, dk_dv_per_step, dbias_per_step
causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK)
no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK)
def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv_part,
_not_used,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
dk_dv_per_step = jnp.concat(
[dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
)
return dq_per_step, dk_dv_per_step, dbias_per_step
def half_q_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
doutput_part = lax.slice_in_dim(
doutput, q_max_seqlen // 2, q_max_seqlen, axis=1
)
output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1)
softmax_aux_part = lax.slice_in_dim(
softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q_part,
kv,
_not_used,
bias,
softmax_aux_part,
rng_state,
output_part,
doutput_part,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
return dq_per_step, dk_dv_per_step, dbias_per_step
def skip_compute():
return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)
if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
# This is for nested jax.lax.cond
def jax_cond_wrap():
if config.context_parallel_load_balanced:
return lax.cond(
(idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
)
return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)
dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
idx == 0, causal_mask_compute, jax_cond_wrap
)
else:
dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()
kv_next, dk_dv = jnp.unstack(kv_dk_dv)
dq = dq + dq_per_step
dk_dv = dk_dv + dk_dv_per_step
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
dbias = dbias + dbias_per_step
return (kv_next, dq, dk_dv, dbias)
carry = (kv, dq, dk_dv, dbias)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(kv, dq, dk_dv, dbias) = carry
# Final permute to return the KV gradients to their owning ranks.
dk_dv = helper.permute_kv(dk_dv, cp_perm)
global_dbias = dbias
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv)
return dq, dk, dv, global_dbias
return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str):
if not cp_axis:
gmr = global_mesh_resource()
@@ -1437,6 +1937,7 @@ def fused_attn_fwd(
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
) -> jnp.ndarray:
@@ -1519,7 +2020,14 @@
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
-return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind(
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
case CPStrategy.RING:
primitive = FusedRingAttnFwdPrimitive.outer_primitive
return primitive.bind(
*qkv_for_primitive,
bias,
q_seqlen,
@@ -1550,6 +2058,7 @@ def fused_attn_bwd(
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
@@ -1636,7 +2145,14 @@
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
-*qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind(
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
case CPStrategy.RING:
primitive = FusedRingAttnBwdPrimitive.outer_primitive
*qkv_grads, bias_grad = primitive.bind(
*qkv_for_primitive,
bias,
softmax_aux,
......
@@ -167,3 +167,25 @@ def is_ffi_enabled():
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
return is_supported and is_enabled
def get_xla_flag(flag: str, default=None, cast=str):
"""
Returns the value of a flag/option in the XLA_FLAGS environment variable if present, otherwise returns the default value.
"""
xla_flags = []
if xla_flags_env := os.getenv("XLA_FLAGS"):
xla_flags.extend(xla_flags_env.split())
for flag_i in sorted(xla_flags):
if "=" in flag_i:
# option like --xla_abc=foo
name, val = flag_i.split("=", 2)
if name == flag:
return val if cast is None else cast(val)
else:
# flag like --xla_enable_foo
name, val = flag_i, None
if name == flag:
return True
return default
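# Example (matching the tests added in this change): with
# XLA_FLAGS="--foo --bar=1", get_xla_flag("--foo") returns True,
# get_xla_flag("--bar", cast=int) returns 1, and a flag that is absent
# returns `default`.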
@@ -197,7 +197,7 @@ class MeshResource:
The axis name in Mesh used to split the batch and weights along.
If it is None, then full-sharded data parallelism is disabled.
pp_resource : str, default = None
-The axis name in Mesh used to split model layers. along.
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None
The axis name in Mesh used to split sequence (context) dimensions along