Unverified commit 20c75295, authored by Michael Goldfarb, committed by GitHub
Browse files

[JAX] Fix correctness of JAX fused attention with CP and improve numerics...


[JAX] Fix correctness of JAX fused attention with CP and improve numerics check in unit tests (#1282)

Fix correctness of JAX fused attention with CP.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent d9b4bfb5
...@@ -17,7 +17,13 @@ from distributed_test_base import ( ...@@ -17,7 +17,13 @@ from distributed_test_base import (
generate_collectives_count, generate_collectives_count,
compare_ops, compare_ops,
) )
from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose from utils import (
make_causal_mask,
make_self_mask,
assert_tree_like_allclose,
assert_allclose,
print_debug_tensor_stats,
)
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
is_fused_attn_kernel_available, is_fused_attn_kernel_available,
...@@ -31,6 +37,8 @@ from transformer_engine.jax.attention import ( ...@@ -31,6 +37,8 @@ from transformer_engine.jax.attention import (
inverse_reorder_causal_load_balancing, inverse_reorder_causal_load_balancing,
) )
# We will use the golden reference model from our non distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
DTYPES = [jnp.float16, jnp.bfloat16] DTYPES = [jnp.float16, jnp.bfloat16]
...@@ -327,18 +335,27 @@ class TestDistributedCrossAttn: ...@@ -327,18 +335,27 @@ class TestDistributedCrossAttn:
) )
class TestDistributedContexParallelSelfAttn: class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
batch, seqlen, heads, hidden = shape batch, seqlen, heads, hidden = shape
kv_shape = (batch, seqlen, heads // kv_groups, hidden)
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype) q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
mask = None def gen_valid(bs, max_seqlen, pad_ratio):
if attn_mask_type == AttnMaskType.CAUSAL_MASK: pad_len = int(max_seqlen * pad_ratio)
mask = make_causal_mask(batch, seqlen) valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return tokens, jnp.logical_not(tokens)
from test_fused_attn import make_mask
q_idx, _ = gen_valid(batch, seqlen, 0.0)
kv_idx, _ = gen_valid(batch, seqlen, 0.0)
mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type)
return q, k, v, mask return q, k, v, mask
...@@ -382,7 +399,8 @@ class TestDistributedContexParallelSelfAttn: ...@@ -382,7 +399,8 @@ class TestDistributedContexParallelSelfAttn:
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")] "load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
) )
def test_contex_parallel_self_attn( def test_contex_parallel_self_attn(
self, self,
...@@ -400,12 +418,12 @@ class TestDistributedContexParallelSelfAttn: ...@@ -400,12 +418,12 @@ class TestDistributedContexParallelSelfAttn:
attn_bias_type = AttnBiasType.NO_BIAS attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
scaling_factor = 1.0
dp_size, cp_size, tp_size = mesh_shape dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout) qkv_format = get_qkv_format(qkv_layout)
_, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
dtype, dtype,
...@@ -424,54 +442,69 @@ class TestDistributedContexParallelSelfAttn: ...@@ -424,54 +442,69 @@ class TestDistributedContexParallelSelfAttn:
): ):
pytest.skip(f"No FusedAttn backend found") pytest.skip(f"No FusedAttn backend found")
if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
# make sure the mesh even divides cp and tp axis # make sure the mesh even divides cp and tp axis
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
def target_func(q, k, v, mask): def target_func(q, k, v, mask):
return jnp.mean( return fused_attn(
fused_attn( self.qkv_to_layout(q, k, v, qkv_layout),
self.qkv_to_layout(q, k, v, qkv_layout), None, # bias
bias=None, mask,
mask=mask, None, # seed
seed=None, attn_bias_type=attn_bias_type,
attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type,
attn_mask_type=attn_mask_type, qkv_layout=qkv_layout,
qkv_layout=qkv_layout, scaling_factor=scaling_factor,
scaling_factor=scaling_factor, dropout_probability=dropout_prob,
dropout_probability=dropout_prob, is_training=is_training,
is_training=is_training, context_parallel_causal_load_balanced=load_balanced,
context_parallel_causal_load_balanced=load_balanced, context_parallel_axis="cp",
),
).astype(dtype) ).astype(dtype)
def ref_func(q, k, v, mask, kv_groups): def ref_func(q, k, v, mask):
q = jnp.squeeze(q) output = general_dot_product_attention(
k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
output = dot_product_attention(
q, q,
k, k,
v, v,
bias=None, bias=None,
mask=mask, mask=mask,
deterministic=is_training, deterministic=not is_training,
scale_factor=scaling_factor,
dropout_rate=dropout_prob, dropout_rate=dropout_prob,
dropout_rng=None, dropout_rng=None,
dtype=jnp.float32, dtype=jnp.float32,
) )
return jnp.mean(output).astype(dtype) return output.astype(dtype)
def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
_, max_seq_len, num_heads, _ = data_shape
gradient_multiplier = max_seq_len * num_heads
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
gradient_multiplier /= 10
ret_valid = func(*args, **kwargs)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)
diff_argnums = (0, 1, 2)
# Single GPU (reference) # Single GPU (reference)
ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4]) ref_func_jit = jax.jit(
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups) jax.value_and_grad(
lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums
)
)
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask)
# Multi GPU (function under test) # Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False):
qkv_ps = PartitionSpec( qkv_ps = PartitionSpec(
mesh_resource.dp_resource, mesh_resource.dp_resource,
mesh_resource.cp_resource, mesh_resource.cp_resource,
...@@ -499,7 +532,10 @@ class TestDistributedContexParallelSelfAttn: ...@@ -499,7 +532,10 @@ class TestDistributedContexParallelSelfAttn:
mask_ = jax.device_put(mask, device=mask_sharding) mask_ = jax.device_put(mask, device=mask_sharding)
target_func_jit = jax.jit( target_func_jit = jax.jit(
jax.value_and_grad(target_func, argnums=[0, 1, 2]), jax.value_and_grad(
lambda q, k, v, mask: grad_func(target_func, q, k, v, mask),
argnums=diff_argnums,
),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
) )
...@@ -510,37 +546,25 @@ class TestDistributedContexParallelSelfAttn: ...@@ -510,37 +546,25 @@ class TestDistributedContexParallelSelfAttn:
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])
def _print_diffs(target, ref):
print("min: ", jnp.min(target), jnp.min(ref))
print("max: ", jnp.max(target), jnp.max(ref))
print("mean: ", jnp.mean(target), jnp.mean(ref))
print("median: ", jnp.median(target), jnp.median(ref))
print("std: ", jnp.std(target), jnp.std(ref))
print("var: ", jnp.var(target), jnp.var(ref))
print("max diff: ", jnp.max(jnp.abs(target - ref)))
has_diffs = False has_diffs = False
try: print_debug_tensor_stats("target", target_fwd)
assert_allclose(target_fwd, ref_fwd, dtype=dtype) print_debug_tensor_stats("ref", ref_fwd)
except AssertionError as e: print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd))
has_diffs = True assert_allclose(target_fwd, ref_fwd, dtype=dtype)
print(f"target_fwd v. ref_fwd")
_print_diffs(target_fwd, ref_fwd)
for i in range(len(target_grads)): for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None: if ref_grads[i] is None or target_grads[i] is None:
# expect both none if one is # expect both none if one is
assert target_grads[i] is None and ref_grads[i] is None assert target_grads[i] is None and ref_grads[i] is None
else: else:
try: print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i])
assert_allclose(target_grads[i], ref_grads[i]) print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i])
except AssertionError as e: print_debug_tensor_stats(
has_diffs = True f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i])
print(f"target_grads[{i}] v. ref_grads[{i}]") )
_print_diffs(target_grads[i], ref_grads[i])
assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
assert has_diffs == False, "has_diffs != False"
class TestReorderCausalLoadBalancing: class TestReorderCausalLoadBalancing:
......
...@@ -7,6 +7,7 @@ import functools ...@@ -7,6 +7,7 @@ import functools
import math import math
import operator import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -30,6 +31,9 @@ PrecisionLike = Union[ ...@@ -30,6 +31,9 @@ PrecisionLike = Union[
] ]
Initializer = Callable[[PRNGKey, Shape, DType], Array] Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))
def is_devices_enough(required): def is_devices_enough(required):
""" """
...@@ -1466,3 +1470,23 @@ def sync_params_values(dst, src, transformations, sep="/"): ...@@ -1466,3 +1470,23 @@ def sync_params_values(dst, src, transformations, sep="/"):
synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values) synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)
return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst) return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
    """Print summary statistics (mean/min/max/numel/nonzero count) of a tensor.

    This is a no-op unless the module-level NVTE_DEBUG_NUMERICS flag is set
    (read from the environment at import time). Printing goes through
    jax.debug.print so it also works from inside jit-compiled code.

    Args:
        prefix: Static label prepended to the printed line.
        tensor: Array whose statistics are reported.
        hist: If True (static), additionally print a 10-bin histogram
            (bin counts and bin edges).
    """
    if NVTE_DEBUG_NUMERICS:
        args = [
            jnp.mean(tensor),
            jnp.min(tensor),
            jnp.max(tensor),
            # tensor.size is the statically-known element count; it replaces
            # the former cumprod-of-shape computation and also yields 1 for
            # rank-0 tensors.
            tensor.size,
            jnp.count_nonzero(tensor),
        ]
        fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"

        if hist:
            h = jnp.histogram(tensor.astype(jnp.float32), bins=10)
            args += [h[0], h[1]]
            fmt = fmt + "\n {}\n {}"

        jax.debug.print(fmt, *args)
...@@ -242,73 +242,16 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): ...@@ -242,73 +242,16 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen return batch, q_max_seqlen, kv_max_seqlen
def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool):
match tensor_format:
case QKVFormat.SBHD:
seq_dim = 0
case QKVFormat.BSHD:
seq_dim = 1
case _:
raise ValueError(f"{tensor_format=} is not supported for causal load balancing.")
if cp_size == 1:
return tensor
if cp_size % 2 != 0:
raise ValueError(f"{cp_size=} must be a multiple of 2.")
# Need to ensure we have 2 pairs to swap for balancing between cp ranks
if tensor.shape[seq_dim] % (cp_size * 2) != 0:
raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")
# [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D]
# [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D]
ori_tensor_shape = tensor.shape
tensor = tensor.reshape(
(
*ori_tensor_shape[:seq_dim],
2 * cp_size,
ori_tensor_shape[seq_dim] // (2 * cp_size),
*ori_tensor_shape[seq_dim + 1 :],
)
)
parts = []
if not inverse:
for cp_rank in range(cp_size):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
parts.append(jnp.take(tensor, index, axis=seq_dim))
else:
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
base = 4 * cp_rank
index = jnp.array([base, base + 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
base = 2 * cp_size - 1 - 4 * cp_rank
index = jnp.array([base, base - 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D]
combined = jnp.stack(parts, axis=seq_dim)
return combined.reshape(ori_tensor_shape)
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
"""Reorders a tensor for load balancing the compute of causal attention.""" """Reorders a tensor for load balancing the compute of causal attention."""
return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False) seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
"""Inverse operation of `reorder_causal_load_balancing`.""" """Inverse operation of `reorder_causal_load_balancing`."""
return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True) seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
def fused_attn( def fused_attn(
......
...@@ -911,6 +911,58 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -911,6 +911,58 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive) register_primitive(FusedAttnBwdPrimitive)
def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention."""
    # With a single context-parallel rank there is nothing to rebalance.
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # Each rank swaps a pair of chunks, so the sequence must split into
    # 2*cp_size equal chunks.
    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
        raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")

    orig_shape = tensor.shape
    num_chunks = 2 * cp_size

    # Expose the sequence axis as (num_chunks, chunk_len) so chunks can be
    # permuted with a single gather along seq_dim.
    chunked = tensor.reshape(
        (
            *orig_shape[:seq_dim],
            num_chunks,
            orig_shape[seq_dim] // num_chunks,
            *orig_shape[seq_dim + 1 :],
        )
    )

    if not to_contiguous:
        # Rank r receives the pair (r, num_chunks - 1 - r): one early and one
        # late chunk, balancing causal work across ranks.
        chunk_order = [
            idx for rank in range(cp_size) for idx in (rank, num_chunks - rank - 1)
        ]
    else:
        # Undo the balanced layout: walk the even positions forward, then the
        # odd positions backward, restoring contiguous chunk order.
        half = cp_size // 2
        chunk_order = [
            idx for rank in range(half) for idx in (4 * rank, 4 * rank + 2)
        ]
        chunk_order += [
            idx
            for rank in range(half)
            for idx in (num_chunks - 1 - 4 * rank, num_chunks - 3 - 4 * rank)
        ]

    reordered = jnp.take(chunked, jnp.array(chunk_order), axis=seq_dim)
    return reordered.reshape(orig_shape)
@dataclass(frozen=True) @dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper: class _FusedAttnCPWithAllGatherHelper:
"""Helper class to assist with running the all-gather strategy for CP attention.""" """Helper class to assist with running the all-gather strategy for CP attention."""
...@@ -954,13 +1006,32 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -954,13 +1006,32 @@ class _FusedAttnCPWithAllGatherHelper:
return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type return self.config.attn_mask_type
def get_step_config(self) -> _FusedAttnConfig:
    """Returns a _FusedAttnConfig for single CP step call to fused attention."""
    # Copy of the user-provided config with only the mask type replaced by the
    # per-step adjusted mask from get_adjusted_mask() (which appears to map a
    # causal mask to a bottom-right-causal one for CP steps — confirm against
    # get_adjusted_mask's full body).
    return _FusedAttnConfig(
        attn_bias_type=self.config.attn_bias_type,
        attn_mask_type=self.get_adjusted_mask(),
        qkv_layout=self.config.qkv_layout,
        scaling_factor=self.config.scaling_factor,
        dropout_probability=self.config.dropout_probability,
        is_training=self.config.is_training,
        max_segments_per_seq=self.config.max_segments_per_seq,
        window_size=self.config.window_size,
        context_parallel_load_balanced=self.config.context_parallel_load_balanced,
        cp_axis=self.config.cp_axis,
    )
def all_gather_kv(self, k, v): def all_gather_kv(self, k, v):
"""Performs a all-gather of k and v over context parallel ranks.""" """Performs a all-gather of k and v over context parallel ranks."""
def ag(x): def ag(x):
return lax_paral_op( x = lax_paral_op(
x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
) )
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True)
return x
match self.config.qkv_layout: match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
...@@ -974,6 +1045,10 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -974,6 +1045,10 @@ class _FusedAttnCPWithAllGatherHelper:
"""Performs a reduce-scatter of dk and dv over context parallel ranks.""" """Performs a reduce-scatter of dk and dv over context parallel ranks."""
def rs(x): def rs(x):
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False)
return lax_paral_op( return lax_paral_op(
x, x,
lax.psum_scatter, lax.psum_scatter,
...@@ -1078,7 +1153,6 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1078,7 +1153,6 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed):
cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
...@@ -1120,7 +1194,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1120,7 +1194,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
q_seq_offsets, q_seq_offsets,
k_seq_offsets, k_seq_offsets,
seed, seed,
config=config, config=helper.get_step_config(),
) )
results.append((output, softmax_aux, rng_state)) results.append((output, softmax_aux, rng_state))
...@@ -1237,7 +1311,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1237,7 +1311,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen_for_step, kv_seqlen_for_step,
q_seq_offsets, q_seq_offsets,
k_seq_offsets, k_seq_offsets,
config=config, config=helper.get_step_config(),
) )
# pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment