Unverified Commit 9101a78f authored by Michael Goldfarb, committed by GitHub

[JAX] Context Parallel Attention with All-Gather (#1106)



Implementation of context parallel fused attention using all-gather.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent d2d4cf91
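A minimal usage sketch of the new context-parallel path (editor's illustration; the 8-device layout, tensor sizes, and the axis name "cp" are assumptions, while the API calls mirror the test added in this commit):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, QKVFormat, QKVLayout,
    fused_attn, reorder_causal_load_balancing,
)

# Assumed layout: 8 devices split as dp=2, cp=2, tp=2.
devices = np.asarray(jax.devices()[:8]).reshape(2, 2, 2)
mesh = Mesh(devices, ("dp", "cp", "tp"))

# Toy activations: [batch, seqlen, heads, head_dim] in BSHD format.
q = k = v = jnp.zeros((2, 512, 12, 64), dtype=jnp.bfloat16)

# Pre-reorder the sequence dimension so the causal workload is balanced
# across the CP ranks, then tell fused_attn the inputs are reordered.
q, k, v = (reorder_causal_load_balancing(t, 2, QKVFormat.BSHD) for t in (q, k, v))

with mesh, fp8_autocast(
    mesh_resource=MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
):
    sharding = NamedSharding(mesh, PartitionSpec("dp", "cp", "tp", None))
    q, k, v = (jax.device_put(t, sharding) for t in (q, k, v))

    @jax.jit
    def attn(q, k, v):
        return fused_attn(
            (q, k, v),
            bias=None,
            mask=None,  # unused for CAUSAL_MASK; sequence lengths come from the shapes
            seed=None,
            attn_bias_type=AttnBiasType.NO_BIAS,
            attn_mask_type=AttnMaskType.CAUSAL_MASK,
            qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
            scaling_factor=1.0,
            dropout_probability=0.0,
            is_training=True,
            context_parallel_causal_load_balanced=True,
            context_parallel_axis="cp",
        )

    out = attn(q, k, v)  # [2, 512, 12, 64], sharded over the dp/cp/tp axes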
......@@ -4,6 +4,8 @@
import operator
import re
from functools import reduce
from itertools import product
import pytest
import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
......@@ -29,6 +31,28 @@ def generate_configs():
return configs
def generate_context_parallel_configs():
configs = []
DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp
if is_devices_enough(ndev):
configs.append(
pytest.param(
ndev,
(dp, cp, tp),
("dp", "cp", "tp"),
MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
)
)
return configs
COLL_AR_KEY = "all-reduce"
COLL_AG_KEY = "all-gather"
COLL_OTHER_KEY = "other"
......
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
import pytest
from functools import partial
import jax
import jax.numpy as jnp
......@@ -10,8 +11,13 @@ import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
......@@ -19,6 +25,10 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
)
......@@ -263,7 +273,8 @@ class TestDistributedCrossAttn:
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
),
dtype=jnp.float32,
)
def ref_func(query, kv, mask):
......@@ -284,7 +295,7 @@ class TestDistributedCrossAttn:
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
return jnp.mean(output, dtype=jnp.float32)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
......@@ -310,3 +321,229 @@ class TestDistributedCrossAttn:
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)
class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
batch, seqlen, heads, hidden = shape
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_causal_mask(batch, seqlen)
return q, k, v, mask
def qkv_to_layout(self, q, k, v, qkv_layout):
qkv_args = ()
match qkv_layout:
case QKVLayout.BSHD_BS2HD:
k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
kv = jnp.concatenate((k, v), axis=-3)
qkv_args = (q, kv)
case QKVLayout.BSHD_BSHD_BSHD:
qkv_args = (q, k, v)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args
@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")]
)
def test_context_parallel_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)
_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups
# make sure the heads divide evenly across kv_groups and the resulting KV heads divide evenly across the TP axis
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping: {num_head=} is not divisible by {kv_groups=}, or {num_kv_heads=} is not divisible by {tp_size=}")
def target_func(q, k, v, mask):
return jnp.mean(
fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
bias=None,
mask=mask,
seed=None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_causal_load_balanced=load_balanced,
),
).astype(dtype)
def ref_func(q, k, v, mask, kv_groups):
q = jnp.squeeze(q)
k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
output = dot_product_attention(
q,
k,
v,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)
# Single GPU (reference)
ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4])
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups)
# Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ps = PartitionSpec(
mesh_resource.dp_resource,
mesh_resource.cp_resource,
mesh_resource.tp_resource,
None,
)
qkv_sharding = NamedSharding(mesh, qkv_ps)
mask_ps = PartitionSpec(
mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
)
mask_sharding = NamedSharding(mesh, mask_ps)
reorder = partial(
reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
inverse_reorder = partial(
inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
if load_balanced:
q, k, v = jax.tree.map(reorder, (q, k, v))
q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
mask_ = jax.device_put(mask, device=mask_sharding)
target_func_jit = jax.jit(
jax.value_and_grad(target_func, argnums=[0, 1, 2]),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
)
target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)
if load_balanced:
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])
def _print_diffs(target, ref):
print("min: ", jnp.min(target), jnp.min(ref))
print("max: ", jnp.max(target), jnp.max(ref))
print("mean: ", jnp.mean(target), jnp.mean(ref))
print("median: ", jnp.median(target), jnp.median(ref))
print("std: ", jnp.std(target), jnp.std(ref))
print("var: ", jnp.var(target), jnp.var(ref))
print("max diff: ", jnp.max(jnp.abs(target - ref)))
has_diffs = False
try:
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
except AssertionError as e:
has_diffs = True
print(f"target_fwd v. ref_fwd")
_print_diffs(target_fwd, ref_fwd)
for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None:
# if either gradient is None, expect both to be None
assert target_grads[i] is None and ref_grads[i] is None
else:
try:
assert_allclose(target_grads[i], ref_grads[i])
except AssertionError as e:
has_diffs = True
print(f"target_grads[{i}] v. ref_grads[{i}]")
_print_diffs(target_grads[i], ref_grads[i])
assert not has_diffs, "forward output or gradients differ from the reference"
class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest.mark.parametrize(
"shape",
[
pytest.param([1, 16, 1, 1], id="1-16-1-1"),
pytest.param([4, 32, 12, 32], id="4-32-12-32"),
pytest.param([3, 32, 8, 64], id="3-32-8-64"),
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])
reordered = reorder(tensor, cp_size, qkv_format)
inversed = inverse(reordered, cp_size, qkv_format)
assert jnp.array_equal(inversed, ref)
......@@ -43,6 +43,8 @@ class AttnMaskType(Enum):
PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK
CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK
PADDING_CAUSAL_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
class QKVLayout(Enum):
......@@ -97,11 +99,21 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
return AttnMaskType.PADDING_MASK
case "causal":
return AttnMaskType.CAUSAL_MASK
case "causal_bottom_right" | "bottom_right_causal":
return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
case "padding_causal" | "causal_padding":
return AttnMaskType.PADDING_CAUSAL_MASK
case (
"padding_causal_bottom_right"
| "causal_padding_bottom_right"
| "bottom_right_causal_padding"
| "bottom_right_padding_causal"
):
return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
raise ValueError(
f"Unsupported {attn_mask_type=}, supported attn_mask_type="
"{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}"
f"Unsupported {attn_mask_type=}, supported attn_mask_type={{'no_mask', 'padding', 'causal',"
" 'padding_causal', 'causal_padding', 'causal_bottom_right',"
" 'padding_causal_bottom_right'}"
)
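# Quick reference (editor's note): the newly accepted mask-type aliases resolve as
#   "causal_bottom_right" / "bottom_right_causal"              -> AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
#   "padding_causal_bottom_right" (and its word permutations)  -> AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK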
......@@ -155,6 +167,75 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool):
match tensor_format:
case QKVFormat.SBHD:
seq_dim = 0
case QKVFormat.BSHD:
seq_dim = 1
case _:
raise ValueError(f"{tensor_format=} is not supported for causal load balancing.")
if cp_size == 1:
return tensor
if cp_size % 2 != 0:
raise ValueError(f"{cp_size=} must be a multiple of 2.")
# The sequence must split into 2*cp_size chunks so each CP rank gets a balanced pair of chunks
if tensor.shape[seq_dim] % (cp_size * 2) != 0:
raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")
# [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
# [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]
ori_tensor_shape = tensor.shape
tensor = tensor.reshape(
(
*ori_tensor_shape[:seq_dim],
2 * cp_size,
ori_tensor_shape[seq_dim] // (2 * cp_size),
*ori_tensor_shape[seq_dim + 1 :],
)
)
parts = []
if not inverse:
for cp_rank in range(cp_size):
# [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D] -> [B, 2, S/(2*cp_size), H, D]
# [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D] -> [2, S/(2*cp_size), B, H, D]
index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
parts.append(jnp.take(tensor, index, axis=seq_dim))
else:
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D] -> [B, 2, S/(2*cp_size), H, D]
# [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D] -> [2, S/(2*cp_size), B, H, D]
base = 4 * cp_rank
index = jnp.array([base, base + 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D] -> [B, 2, S/(2*cp_size), H, D]
# [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D] -> [2, S/(2*cp_size), B, H, D]
base = 2 * cp_size - 1 - 4 * cp_rank
index = jnp.array([base, base - 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
# [B, S, H, D]: [B, 2*cp_size, S/(2*cp_size), H, D]
# [S, B, H, D]: [2*cp_size, S/(2*cp_size), B, H, D]
combined = jnp.stack(parts, axis=seq_dim)
return combined.reshape(ori_tensor_shape)
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
"""Reorders a tensor for load balancing the compute of causal attention."""
return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False)
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
"""Inverse operation of `reorder_causal_load_balancing`."""
return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True)
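# Round-trip illustration (editor's sketch, not part of the change): with cp_size=2
# the sequence is split into 2*cp_size = 4 chunks; CP rank 0 receives chunks (0, 3)
# and rank 1 receives chunks (1, 2), so every rank holds one "cheap" and one
# "expensive" causal block. For example:
#
#     x = jnp.arange(8).reshape(1, 8, 1, 1)                    # [B=1, S=8, H=1, D=1]
#     y = reorder_causal_load_balancing(x, 2, QKVFormat.BSHD)  # seq order: 0,1,6,7,2,3,4,5
#     z = inverse_reorder_causal_load_balancing(y, 2, QKVFormat.BSHD)
#     assert (z == x).all()                                    # original order restored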
def fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -166,6 +247,8 @@ def fused_attn(
scaling_factor: float,
dropout_probability: float,
is_training: bool,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
"""
Perform non-THD (non-packed) cuDNN fused attention.
......@@ -192,6 +275,9 @@ def fused_attn(
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
context_parallel_causal_load_balanced (bool):
Indicates whether the sequences are reordered for causal-mask load balancing when running with context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
"""
......@@ -213,7 +299,11 @@ def fused_attn(
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
# convert the mask to seqlens, mask doesn't support ragged offsets
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
if attn_mask_type in [
AttnMaskType.NO_MASK,
AttnMaskType.CAUSAL_MASK,
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
]:
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
......@@ -242,6 +332,8 @@ def fused_attn(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=1,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
return output
......@@ -262,6 +354,8 @@ def fused_attn_thd(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int = 1,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
"""
(Experimental) Perform THD (packed) cuDNN fused attention.
......@@ -300,6 +394,9 @@ def fused_attn_thd(
Indicates the maximum number of segments inside a sequence. This parameter constrains
resource usage and must be static during end-to-end training. XLA compile
time and memory consumption are proportional to `max_segments_per_seq`.
context_parallel_causal_load_balanced (bool):
Indicates whether the sequences are reordered for causal-mask load balancing when running with context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -354,12 +451,14 @@ def fused_attn_thd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -375,6 +474,8 @@ def _fused_attn(
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
output, _ = _fused_attn_fwd_rule(
qkv,
......@@ -391,6 +492,8 @@ def _fused_attn(
dropout_probability,
is_training,
max_segments_per_seq,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
return output
......@@ -410,6 +513,8 @@ def _fused_attn_fwd_rule(
dropout_probability,
is_training,
max_segments_per_seq,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
......@@ -426,6 +531,8 @@ def _fused_attn_fwd_rule(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
......@@ -451,6 +558,8 @@ def _fused_attn_bwd_rule(
dropout_probability,
is_training,
max_segments_per_seq,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
dz,
):
......@@ -483,6 +592,8 @@ def _fused_attn_bwd_rule(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
......
......@@ -100,7 +100,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
.value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
.value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK);
.value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
......
......@@ -20,6 +20,7 @@ _PXLA_THREAD_RESOURCES = pxla.thread_resources
BATCH_AXES = "nvte_batch"
SEQLEN_AXES = "nvte_seqlen"
SEQLEN_TP_AXES = "nvte_seqlen_tp"
SEQLEN_CP_AXES = "nvte_seqlen_cp"
HEAD_AXES = "nvte_head"
HIDDEN_AXES = "nvte_hidden"
HIDDEN_TP_AXES = "nvte_hidden_tp"
......@@ -65,6 +66,7 @@ def get_sharding_map_logic_axis_to_mesh_axis():
BATCH_AXES: batch_dim_rule,
SEQLEN_AXES: None,
SEQLEN_TP_AXES: gsr.tp_resource,
SEQLEN_CP_AXES: gsr.cp_resource,
HEAD_AXES: gsr.tp_resource,
HIDDEN_AXES: None,
HIDDEN_TP_AXES: gsr.tp_resource,
......@@ -131,13 +133,15 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh):
def lax_paral_op(
x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs
):
"""
A wrapper function to invoke lax.p* operations, like psum.
"""
if mesh_resource is not None:
_, resource = _get_mesh_info(mesh_resource, mesh)
return ops(x, resource)
return ops(x, resource, **kwargs)
return x
......@@ -148,6 +152,33 @@ def num_of_devices():
return len(jax.devices())
def get_mesh_axis_size(axis, mesh=None):
"""
Get the size of the given axis in the mesh.
If the mesh is None, the global mesh is used.
"""
if mesh is None:
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if axis is None:
return 1
assert axis in mesh.shape, f"{axis} is not an axis of the given mesh {mesh.shape}"
return mesh.shape[axis]
def get_mesh_axis_rank(axis: str, mesh=None):
"""
Gets the rank of the local device along `axis` in the mesh.
If the mesh is None, the rank is 0.
"""
if mesh is None:
return 0
_, axis_name = _get_mesh_info(axis, mesh)
return jax.lax.axis_index(axis_name)
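# Usage note (editor's illustration): for a mesh built as
#   Mesh(devices.reshape(2, 4, 2), ("dp", "cp", "tp"))
# get_mesh_axis_size("cp", mesh) returns 4, while get_mesh_axis_rank("cp", mesh)
# returns jax.lax.axis_index for that axis, i.e. the position of the local shard
# along "cp"; the index only resolves inside a computation where the mesh axis is
# bound (e.g. under a sharded/partitioned primitive).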
@dataclass
class MeshResource:
"""
......@@ -168,12 +199,16 @@ class MeshResource:
pp_resource : str, default = None
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None
The axis name in Mesh used to split the sequence (context) dimension along
in attention. If it is None, then context parallelism is disabled.
"""
dp_resource: str = None
tp_resource: str = None
fsdp_resource: str = None
pp_resource: str = None
cp_resource: str = None
_GLOBAL_MESH_RESOURCE = MeshResource()
......