Unverified Commit bfddb483 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Support Ring Attention (Context Parallelism) (#1059)



* Implement ring attention primitive for JAX.
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2643ba1d
...@@ -5,4 +5,11 @@ ...@@ -5,4 +5,11 @@
set -xe set -xe
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
# Skip ring attention tests since they need fixed environment vars
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
# Test ring attention with and without scan loop
NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
...@@ -35,7 +35,9 @@ from transformer_engine.jax.attention import ( ...@@ -35,7 +35,9 @@ from transformer_engine.jax.attention import (
get_qkv_format, get_qkv_format,
reorder_causal_load_balancing, reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing, inverse_reorder_causal_load_balancing,
CPStrategy,
) )
from transformer_engine.jax.sharding import MeshResource
# We will use the golden reference model from our non distributed attention test fixture. # We will use the golden reference model from our non distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask from test_fused_attn import general_dot_product_attention, make_mask
...@@ -333,6 +335,36 @@ class TestDistributedCrossAttn: ...@@ -333,6 +335,36 @@ class TestDistributedCrossAttn:
) )
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn: class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
...@@ -370,37 +402,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -370,37 +402,7 @@ class TestDistributedContextParallelSelfAttn:
raise ValueError(f"Unsupported {qkv_layout=}") raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args return qkv_args
@pytest.mark.parametrize( def impl_test_contex_parallel_attn(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
def test_contex_parallel_self_attn(
self, self,
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -412,6 +414,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -412,6 +414,7 @@ class TestDistributedContextParallelSelfAttn:
dtype, dtype,
qkv_layout, qkv_layout,
load_balanced, load_balanced,
cp_strategy,
): ):
attn_bias_type = AttnBiasType.NO_BIAS attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0 dropout_prob = 0.0
...@@ -469,6 +472,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -469,6 +472,7 @@ class TestDistributedContextParallelSelfAttn:
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_prob, dropout_probability=dropout_prob,
is_training=is_training, is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced, context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp", context_parallel_axis="cp",
).astype(dtype) ).astype(dtype)
...@@ -574,6 +578,60 @@ class TestDistributedContextParallelSelfAttn: ...@@ -574,6 +578,60 @@ class TestDistributedContextParallelSelfAttn:
assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
def test_contex_parallel_allgather_attn(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    data_shape,
    kv_groups,
    attn_mask_type,
    dtype,
    qkv_layout,
    load_balanced,
):
    """Run the context-parallel self-attention check with the all-gather strategy."""
    common_args = (
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    )
    return self.impl_test_contex_parallel_attn(*common_args, CPStrategy.ALL_GATHER)
def test_context_parallel_ring_attn(
    self,
    device_count,
    mesh_shape,
    mesh_axes,
    mesh_resource,
    data_shape,
    kv_groups,
    attn_mask_type,
    dtype,
    qkv_layout,
    load_balanced,
):
    """Run the context-parallel self-attention check with the ring-attention strategy."""
    common_args = (
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        kv_groups,
        attn_mask_type,
        dtype,
        qkv_layout,
        load_balanced,
    )
    return self.impl_test_contex_parallel_attn(*common_args, CPStrategy.RING)
class TestReorderCausalLoadBalancing: class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest.mark.parametrize("cp_size", [2, 4, 8])
......
...@@ -7,6 +7,7 @@ from dataclasses import dataclass ...@@ -7,6 +7,7 @@ from dataclasses import dataclass
from functools import partial from functools import partial
from math import sqrt from math import sqrt
from typing import Tuple, Optional from typing import Tuple, Optional
import random
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from functools import partial
import os
from transformer_engine.jax.cpp_extensions.misc import get_xla_flag
@pytest.fixture(autouse=True, scope="function")
def preserve_xla_flags():
    """Restore the XLA_FLAGS environment variable after each test in this file.

    Tests here mutate XLA_FLAGS; without cleanup a value set by one test would
    leak into later tests and the rest of the pytest session.
    """
    old_flags = os.getenv("XLA_FLAGS")
    yield
    if old_flags is not None:
        os.environ["XLA_FLAGS"] = old_flags
    else:
        # The variable was unset before the test ran: remove whatever the test
        # may have set instead of leaking it (the original restore-only logic
        # left a test-created XLA_FLAGS value in the environment).
        os.environ.pop("XLA_FLAGS", None)
def test_get_xla_flag(request):
    """Unit tests for get_xla_flag() parsing of the XLA_FLAGS env var."""
    os.environ["XLA_FLAGS"] = ""
    # Nothing can be found in an empty variable.
    assert get_xla_flag("") is None
    assert get_xla_flag("--foo") is None
    assert get_xla_flag("--bar=1") is None

    os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
    # A valueless flag reports its presence as True.
    assert get_xla_flag("--foo") is True
    assert get_xla_flag("--bar") == "1"
    assert get_xla_flag("--bar", cast=int) == 1
    assert get_xla_flag("--bar", cast=bool) is True
    assert get_xla_flag("--baz") == "biz"
    with pytest.raises(ValueError):
        # cast fails: "biz" is not an integer literal
        get_xla_flag("--baz", cast=int)
    # A prefix of a present flag must not match.
    assert get_xla_flag("--xla") is None

    os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
    assert get_xla_flag("--xla_abc") is True
    assert get_xla_flag("--xla_abb") is True
...@@ -79,6 +79,19 @@ class QKVFormat(Enum): ...@@ -79,6 +79,19 @@ class QKVFormat(Enum):
THD = NVTE_QKV_Format.NVTE_THD THD = NVTE_QKV_Format.NVTE_THD
class CPStrategy(Enum):
    """Context-parallel strategies for JAX fused attention.

    Members:
        DEFAULT: Automatically pick a strategy when the context-parallel
            axis is sharded.
        ALL_GATHER: All-gather / reduce-scatter based implementation.
        RING: Ring attention implementation
            (https://arxiv.org/abs/2310.01889).
    """

    DEFAULT = 0
    ALL_GATHER = 1
    RING = 2
def get_qkv_format(qkv_layout): def get_qkv_format(qkv_layout):
""" """
Get qkv_format from qkv_layout Get qkv_format from qkv_layout
...@@ -260,6 +273,7 @@ def fused_attn( ...@@ -260,6 +273,7 @@ def fused_attn(
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
): ):
...@@ -347,6 +361,7 @@ def fused_attn( ...@@ -347,6 +361,7 @@ def fused_attn(
is_training=is_training, is_training=is_training,
max_segments_per_seq=1, max_segments_per_seq=1,
window_size=window_size, window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
...@@ -370,6 +385,7 @@ def fused_attn_thd( ...@@ -370,6 +385,7 @@ def fused_attn_thd(
is_training: bool, is_training: bool,
max_segments_per_seq: int = 1, max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
): ):
...@@ -470,6 +486,7 @@ def fused_attn_thd( ...@@ -470,6 +486,7 @@ def fused_attn_thd(
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=window_size, window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
...@@ -477,7 +494,7 @@ def fused_attn_thd( ...@@ -477,7 +494,7 @@ def fused_attn_thd(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn( def _fused_attn(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
...@@ -494,6 +511,7 @@ def _fused_attn( ...@@ -494,6 +511,7 @@ def _fused_attn(
is_training: bool, is_training: bool,
max_segments_per_seq: int, max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]], window_size: Optional[Tuple[int, int]],
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool, context_parallel_causal_load_balanced: bool,
context_parallel_axis: str, context_parallel_axis: str,
): ):
...@@ -513,6 +531,7 @@ def _fused_attn( ...@@ -513,6 +531,7 @@ def _fused_attn(
is_training, is_training,
max_segments_per_seq, max_segments_per_seq,
window_size, window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
) )
...@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule( ...@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule(
is_training, is_training,
max_segments_per_seq, max_segments_per_seq,
window_size, window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
): ):
...@@ -554,6 +574,7 @@ def _fused_attn_fwd_rule( ...@@ -554,6 +574,7 @@ def _fused_attn_fwd_rule(
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=window_size, window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
...@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule( ...@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule(
is_training, is_training,
max_segments_per_seq, max_segments_per_seq,
window_size, window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
ctx, ctx,
...@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule( ...@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule(
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=window_size, window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
......
...@@ -167,3 +167,25 @@ def is_ffi_enabled(): ...@@ -167,3 +167,25 @@ def is_ffi_enabled():
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1")) is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value" assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
return is_supported and is_enabled return is_supported and is_enabled
def get_xla_flag(flag: str, default=None, cast=str):
    """Look up a flag/option in the XLA_FLAGS environment variable.

    Args:
        flag: Full flag name to search for, e.g. ``"--xla_abc"``.
        default: Value returned when the flag is not present.
        cast: Callable applied to an option's string value before returning
            it; pass ``None`` to get the raw string. May raise (e.g.
            ``ValueError`` from ``int``) if the value cannot be converted.

    Returns:
        ``True`` for a valueless flag (``--xla_abc``), the (possibly cast)
        value for an option (``--xla_abc=foo``), or ``default`` if absent.
    """
    xla_flags = []
    if xla_flags_env := os.getenv("XLA_FLAGS"):
        xla_flags.extend(xla_flags_env.split())
    # Scan back-to-front so that when a flag is repeated the last occurrence
    # wins, matching XLA's own last-one-wins flag parsing. (The previous
    # sorted() scan returned an arbitrary duplicate.)
    for flag_i in reversed(xla_flags):
        if "=" in flag_i:
            # Option like --xla_abc=foo. Split on the first '=' only so a
            # value that itself contains '=' is preserved intact (maxsplit=2
            # used to make the 2-tuple unpacking raise on such values).
            name, val = flag_i.split("=", 1)
            if name == flag:
                return val if cast is None else cast(val)
        else:
            # Valueless flag like --xla_enable_foo: presence means True.
            if flag_i == flag:
                return True
    return default
...@@ -197,7 +197,7 @@ class MeshResource: ...@@ -197,7 +197,7 @@ class MeshResource:
The axis name in Mesh used to split the batch and weights along. The axis name in Mesh used to split the batch and weights along.
If it is None, then full-sharded data parallelism is disabled. If it is None, then full-sharded data parallelism is disabled.
pp_resource : str, default = None pp_resource : str, default = None
The axis name in Mesh used to split model layers. along. The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled. If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None cp_resource : str, default = None
The axis name in Mesh used to split sequence (context) dimensions along The axis name in Mesh used to split sequence (context) dimensions along
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment