Unverified Commit bfddb483 authored by Ming-Xu Huang, committed by GitHub

[JAX] Support Ring Attention (Context Parallelism) (#1059)



* Implement ring attention primitive for JAX.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2643ba1d
......@@ -5,4 +5,11 @@
set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
# Skip ring attention tests here since they require specific environment variables
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
# Test ring attention with and without scan loop
NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
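For local runs outside CI, a minimal Python sketch of the same setup (the variable and flag names are copied from the commands above; setting them from Python instead of the shell is an assumption and must happen before JAX/XLA initializes):

    # Hypothetical local setup mirroring the CI commands above.
    import os

    # Select the scan-based ring attention loop.
    os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
    # The scan path additionally needs this XLA flag; append it to any existing flags.
    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "") + " --xla_experimental_ignore_channel_id"
    ).strip()

    import jax  # import JAX only after the environment is configured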
......@@ -35,7 +35,9 @@ from transformer_engine.jax.attention import (
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
# We will use the golden reference model from our non-distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
......@@ -333,6 +335,36 @@ class TestDistributedCrossAttn:
)
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
......@@ -370,37 +402,7 @@ class TestDistributedContextParallelSelfAttn:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
def test_contex_parallel_self_attn(
def impl_test_context_parallel_attn(
self,
device_count,
mesh_shape,
......@@ -412,6 +414,7 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
load_balanced,
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
......@@ -469,6 +472,7 @@ class TestDistributedContextParallelSelfAttn:
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
......@@ -574,6 +578,60 @@ class TestDistributedContextParallelSelfAttn:
assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
def test_context_parallel_allgather_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
)
def test_context_parallel_ring_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.RING,
)
class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
......
......@@ -7,6 +7,7 @@ from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random
import jax
import jax.numpy as jnp
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from functools import partial
import os
from transformer_engine.jax.cpp_extensions.misc import get_xla_flag
@pytest.fixture(autouse=True, scope="function")
def preserve_xla_flags():
"""Ensures the XLA flags environment variable is restored after any tests in this file run."""
old_flags = os.getenv("XLA_FLAGS")
yield
if old_flags is not None:
    os.environ["XLA_FLAGS"] = old_flags
else:
    os.environ.pop("XLA_FLAGS", None)
def test_get_xla_flag(request):
os.environ["XLA_FLAGS"] = ""
assert get_xla_flag("") is None
assert get_xla_flag("--foo") is None
assert get_xla_flag("--bar=1") is None
os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
assert get_xla_flag("--foo") == True
assert get_xla_flag("--bar") == "1"
assert get_xla_flag("--bar", cast=int) == 1
assert get_xla_flag("--bar", cast=bool) == True
assert get_xla_flag("--baz") == "biz"
with pytest.raises(ValueError):
# cast will fail
assert get_xla_flag("--baz", cast=int)
assert get_xla_flag("--xla") is None
os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
assert get_xla_flag("--xla_abc") == True
assert get_xla_flag("--xla_abb") == True
......@@ -79,6 +79,19 @@ class QKVFormat(Enum):
THD = NVTE_QKV_Format.NVTE_THD
class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention.
DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
ALL_GATHER: All-gather/reduce scatter implementation.
RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
"""
DEFAULT = 0
ALL_GATHER = 1
RING = 2
def get_qkv_format(qkv_layout):
"""
Get qkv_format from qkv_layout
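To illustrate the new strategy selection (a minimal sketch, not part of this diff: the keyword arguments are the ones added by this commit, while the elided tensor arguments and the surrounding mesh/sharding setup are assumptions):

    # Hypothetical usage: run fused attention with ring-attention context
    # parallelism over a mesh axis named "cp", as in the distributed tests above.
    from transformer_engine.jax.attention import fused_attn, CPStrategy

    output = fused_attn(
        # ... qkv, bias, mask and the other existing arguments, unchanged ...
        scaling_factor=0.125,  # 1/sqrt(head_dim) for head_dim=64
        dropout_probability=0.0,
        is_training=True,
        context_parallel_strategy=CPStrategy.RING,
        context_parallel_causal_load_balanced=True,
        context_parallel_axis="cp",
    )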
......@@ -260,6 +273,7 @@ def fused_attn(
dropout_probability: float,
is_training: bool,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
......@@ -347,6 +361,7 @@ def fused_attn(
is_training=is_training,
max_segments_per_seq=1,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......@@ -370,6 +385,7 @@ def fused_attn_thd(
is_training: bool,
max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
......@@ -470,6 +486,7 @@ def fused_attn_thd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......@@ -477,7 +494,7 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -494,6 +511,7 @@ def _fused_attn(
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]],
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
......@@ -513,6 +531,7 @@ def _fused_attn(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
......@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
......@@ -554,6 +574,7 @@ def _fused_attn_fwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
......@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
......
......@@ -167,3 +167,25 @@ def is_ffi_enabled():
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
return is_supported and is_enabled
def get_xla_flag(flag: str, default=None, cast=str):
"""
Returns the value of a flag/option in the XLA_FLAGS environment variable if present, otherwise the default value.
"""
xla_flags = []
if xla_flags_env := os.getenv("XLA_FLAGS"):
xla_flags.extend(xla_flags_env.split())
for flag_i in sorted(xla_flags):
if "=" in flag_i:
# option like --xla_abc=foo
name, val = flag_i.split("=", 1)
if name == flag:
return val if cast is None else cast(val)
else:
# flag like --xla_enable_foo
name, val = flag_i, None
if name == flag:
return True
return default
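A small usage sketch for the new helper (hypothetical caller; the flag name comes from the CI script above, and using it to gate the scan-based ring attention loop is an assumption):

    # Hypothetical: check whether the XLA flag required by the scan-based
    # ring attention loop is present in XLA_FLAGS.
    from transformer_engine.jax.cpp_extensions.misc import get_xla_flag

    if get_xla_flag("--xla_experimental_ignore_channel_id", default=False):
        print("scan-based ring attention loop can be enabled")
    # Options with values can be read and cast, e.g. get_xla_flag("--bar", cast=int).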
......@@ -197,7 +197,7 @@ class MeshResource:
The axis name in Mesh used to split the batch and weights along.
If it is None, then full-sharded data parallelism is disabled.
pp_resource : str, default = None
The axis name in Mesh used to split model layers. along.
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None
The axis name in Mesh used to split sequence (context) dimensions along
......