Unverified Commit 20c75295 authored by Michael Goldfarb's avatar Michael Goldfarb Committed by GitHub
Browse files

[JAX] Fix correctness of JAX fused attention with CP and improve numerics...


[JAX] Fix correctness of JAX fused attention with CP and improve numerics check in unit tests (#1282)

Fix correctness of JAX fused attention with CP.
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
parent d9b4bfb5
......@@ -17,7 +17,13 @@ from distributed_test_base import (
generate_collectives_count,
compare_ops,
)
from utils import make_causal_mask, make_self_mask, assert_tree_like_allclose, assert_allclose
from utils import (
make_causal_mask,
make_self_mask,
assert_tree_like_allclose,
assert_allclose,
print_debug_tensor_stats,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
......@@ -31,6 +37,8 @@ from transformer_engine.jax.attention import (
inverse_reorder_causal_load_balancing,
)
# We will use the golden reference model from our non-distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -327,18 +335,27 @@ class TestDistributedCrossAttn:
)
class TestDistributedContexParallelSelfAttn:
class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
batch, seqlen, heads, hidden = shape
kv_shape = (batch, seqlen, heads // kv_groups, hidden)
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
mask = None
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_causal_mask(batch, seqlen)
def gen_valid(bs, max_seqlen, pad_ratio):
    """Build a (bs, max_seqlen) validity map where the trailing
    ``pad_ratio`` fraction of each sequence is padding.

    Returns a tuple ``(tokens, pad_mask)``: ``tokens`` is 1.0 for valid
    positions and 0.0 for padding; ``pad_mask`` is its boolean negation.
    """
    n_pad = int(max_seqlen * pad_ratio)
    valid = jnp.ones((bs, max_seqlen - n_pad))
    padding = jnp.zeros((bs, n_pad))
    tokens = jnp.concatenate([valid, padding], axis=-1)
    return tokens, jnp.logical_not(tokens)
from test_fused_attn import make_mask
q_idx, _ = gen_valid(batch, seqlen, 0.0)
kv_idx, _ = gen_valid(batch, seqlen, 0.0)
mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type)
return q, k, v, mask
......@@ -382,7 +399,8 @@ class TestDistributedContexParallelSelfAttn:
],
)
@pytest.mark.parametrize(
"load_balanced", [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")]
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
def test_contex_parallel_self_attn(
self,
......@@ -400,12 +418,12 @@ class TestDistributedContexParallelSelfAttn:
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
dp_size, cp_size, tp_size = mesh_shape
qkv_format = get_qkv_format(qkv_layout)
_, seqlen, num_head, hidden = data_shape
batch, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
if not is_fused_attn_kernel_available(
dtype,
......@@ -424,54 +442,69 @@ class TestDistributedContexParallelSelfAttn:
):
pytest.skip(f"No FusedAttn backend found")
if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
# Make sure the head count divides evenly across the kv_groups and the tp axis.
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
def target_func(q, k, v, mask):
return jnp.mean(
fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
bias=None,
mask=mask,
seed=None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_causal_load_balanced=load_balanced,
),
return fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
None, # bias
mask,
None, # seed
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
def ref_func(q, k, v, mask, kv_groups):
q = jnp.squeeze(q)
k = jnp.squeeze(jnp.repeat(k, kv_groups, axis=2))
v = jnp.squeeze(jnp.repeat(v, kv_groups, axis=2))
output = dot_product_attention(
def ref_func(q, k, v, mask):
output = general_dot_product_attention(
q,
k,
v,
bias=None,
mask=mask,
deterministic=is_training,
deterministic=not is_training,
scale_factor=scaling_factor,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
return output.astype(dtype)
def grad_func(func, *args, **kwargs):
    """Reduce ``func``'s output to a scalar loss, amplified so its
    gradients are large enough for meaningful numeric comparison.

    Closes over ``data_shape``, ``attn_mask_type`` and ``dtype`` from the
    enclosing test.
    """
    # The raw gradients are tiny; scale the mean loss by seq_len * heads
    # to amplify them.
    _, seq_len, n_heads, _ = data_shape
    multiplier = seq_len * n_heads
    # Causal masks zero out roughly half the attention, so back off the
    # amplification to avoid overflowing the low-precision dtype.
    causal_types = [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]
    if attn_mask_type in causal_types:
        multiplier /= 10
    out = func(*args, **kwargs)
    return (jnp.mean(out, dtype=jnp.float32) * multiplier).astype(dtype)
q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)
diff_argnums = (0, 1, 2)
# Single GPU (reference)
ref_func_jit = jax.jit(jax.value_and_grad(ref_func, argnums=[0, 1, 2]), static_argnums=[4])
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask, kv_groups)
ref_func_jit = jax.jit(
jax.value_and_grad(
lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums
)
)
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask)
# Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False):
qkv_ps = PartitionSpec(
mesh_resource.dp_resource,
mesh_resource.cp_resource,
......@@ -499,7 +532,10 @@ class TestDistributedContexParallelSelfAttn:
mask_ = jax.device_put(mask, device=mask_sharding)
target_func_jit = jax.jit(
jax.value_and_grad(target_func, argnums=[0, 1, 2]),
jax.value_and_grad(
lambda q, k, v, mask: grad_func(target_func, q, k, v, mask),
argnums=diff_argnums,
),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
)
......@@ -510,37 +546,25 @@ class TestDistributedContexParallelSelfAttn:
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])
def _print_diffs(target, ref):
print("min: ", jnp.min(target), jnp.min(ref))
print("max: ", jnp.max(target), jnp.max(ref))
print("mean: ", jnp.mean(target), jnp.mean(ref))
print("median: ", jnp.median(target), jnp.median(ref))
print("std: ", jnp.std(target), jnp.std(ref))
print("var: ", jnp.var(target), jnp.var(ref))
print("max diff: ", jnp.max(jnp.abs(target - ref)))
has_diffs = False
try:
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
except AssertionError as e:
has_diffs = True
print(f"target_fwd v. ref_fwd")
_print_diffs(target_fwd, ref_fwd)
print_debug_tensor_stats("target", target_fwd)
print_debug_tensor_stats("ref", ref_fwd)
print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd))
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None:
# expect both none if one is
assert target_grads[i] is None and ref_grads[i] is None
else:
try:
assert_allclose(target_grads[i], ref_grads[i])
except AssertionError as e:
has_diffs = True
print(f"target_grads[{i}] v. ref_grads[{i}]")
_print_diffs(target_grads[i], ref_grads[i])
assert has_diffs == False, "has_diffs != False"
print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i])
print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i])
print_debug_tensor_stats(
f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i])
)
assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
class TestReorderCausalLoadBalancing:
......
......@@ -7,6 +7,7 @@ import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os
import jax
import jax.numpy as jnp
......@@ -30,6 +31,9 @@ PrecisionLike = Union[
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))
def is_devices_enough(required):
"""
......@@ -1466,3 +1470,23 @@ def sync_params_values(dst, src, transformations, sep="/"):
synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)
return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
    """Print summary statistics of `tensor` for numerics debugging.

    No-op unless the module-level ``NVTE_DEBUG_NUMERICS`` toggle (read from
    the environment at import time) is set.

    Args:
        prefix: Static label prepended to the printed line.
        tensor: Array whose mean/min/max/numel/nonzero-count are reported.
        hist: Static flag; if True also print a 10-bin histogram
            (counts then bin edges).
    """
    if NVTE_DEBUG_NUMERICS:
        args = [
            jnp.mean(tensor),
            jnp.min(tensor),
            jnp.max(tensor),
            # Element count is static shape info; `size` already handles the
            # 0-d case (size == 1), replacing the cumprod-over-shape trick.
            tensor.size,
            jnp.count_nonzero(tensor),
        ]
        fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"

        if hist:
            counts, edges = jnp.histogram(tensor.astype(jnp.float32), bins=10)
            args += [counts, edges]
            fmt = fmt + "\n {}\n {}"

        jax.debug.print(fmt, *args)
......@@ -242,73 +242,16 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def _reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat, inverse: bool):
    """Permute (or, with ``inverse=True``, un-permute) sequence chunks of
    `tensor` so causal-attention compute is balanced across `cp_size`
    context-parallel ranks.

    The sequence axis is split into ``2*cp_size`` chunks and each rank is
    paired with one chunk from the front and the matching chunk from the
    back of the sequence.
    """
    if tensor_format == QKVFormat.SBHD:
        seq_dim = 0
    elif tensor_format == QKVFormat.BSHD:
        seq_dim = 1
    else:
        raise ValueError(f"{tensor_format=} is not supported for causal load balancing.")

    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
        raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")

    # Expose a chunk axis: split the sequence dim into (2*cp_size, chunk_len).
    ori_shape = tensor.shape
    chunked = tensor.reshape(
        (
            *ori_shape[:seq_dim],
            2 * cp_size,
            ori_shape[seq_dim] // (2 * cp_size),
            *ori_shape[seq_dim + 1 :],
        )
    )

    # Build the chunk permutation as one flat index list and gather once.
    if not inverse:
        # Rank r receives the chunk pair (r, 2*cp_size - 1 - r).
        order = [i for r in range(cp_size) for i in (r, 2 * cp_size - 1 - r)]
    else:
        # Undo the pairing above, restoring contiguous chunk order.
        order = [i for r in range(cp_size // 2) for i in (4 * r, 4 * r + 2)]
        order += [
            i
            for r in range(cp_size // 2)
            for i in (2 * cp_size - 1 - 4 * r, 2 * cp_size - 3 - 4 * r)
        ]

    permuted = jnp.take(chunked, jnp.asarray(order), axis=seq_dim)
    return permuted.reshape(ori_shape)
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
"""Reorders a tensor for load balancing the compute of causal attention."""
return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, False)
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
"""Inverse operation of `reorder_causal_load_balancing`."""
return _reorder_causal_load_balancing(tensor, cp_size, tensor_format, True)
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
def fused_attn(
......
......@@ -911,6 +911,58 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive)
def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention.

    The sequence axis `seq_dim` is split into ``2*cp_size`` chunks; each
    context-parallel rank is paired with one chunk from the front and the
    matching chunk from the back of the sequence.  With
    ``to_contiguous=True`` the permutation is inverted, restoring the
    original contiguous chunk order.
    """
    if cp_size == 1:
        return tensor

    if cp_size % 2 != 0:
        raise ValueError(f"{cp_size=} must be a multiple of 2.")

    # Need to ensure we have 2 pairs to swap for balancing between cp ranks
    if tensor.shape[seq_dim] % (cp_size * 2) != 0:
        raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")

    # Expose a chunk axis: split the sequence dim into (2*cp_size, chunk_len).
    ori_shape = tensor.shape
    chunked = tensor.reshape(
        (
            *ori_shape[:seq_dim],
            2 * cp_size,
            ori_shape[seq_dim] // (2 * cp_size),
            *ori_shape[seq_dim + 1 :],
        )
    )

    # Build the chunk permutation as one flat index list and gather once.
    if not to_contiguous:
        # Rank r receives the chunk pair (r, 2*cp_size - 1 - r).
        order = [i for r in range(cp_size) for i in (r, 2 * cp_size - 1 - r)]
    else:
        # Undo the pairing above, restoring contiguous chunk order.
        order = [i for r in range(cp_size // 2) for i in (4 * r, 4 * r + 2)]
        order += [
            i
            for r in range(cp_size // 2)
            for i in (2 * cp_size - 1 - 4 * r, 2 * cp_size - 3 - 4 * r)
        ]

    permuted = jnp.take(chunked, jnp.asarray(order), axis=seq_dim)
    return permuted.reshape(ori_shape)
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
"""Helper class to assist with running the all-gather strategy for CP attention."""
......@@ -954,13 +1006,32 @@ class _FusedAttnCPWithAllGatherHelper:
return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type
def get_step_config(self) -> _FusedAttnConfig:
    """Returns a _FusedAttnConfig for a single CP step call to fused
    attention: a copy of the base config with the mask type adjusted via
    ``get_adjusted_mask()``; all other fields are carried over unchanged.
    """
    base = self.config
    return _FusedAttnConfig(
        attn_bias_type=base.attn_bias_type,
        attn_mask_type=self.get_adjusted_mask(),
        qkv_layout=base.qkv_layout,
        scaling_factor=base.scaling_factor,
        dropout_probability=base.dropout_probability,
        is_training=base.is_training,
        max_segments_per_seq=base.max_segments_per_seq,
        window_size=base.window_size,
        context_parallel_load_balanced=base.context_parallel_load_balanced,
        cp_axis=base.cp_axis,
    )
def all_gather_kv(self, k, v):
"""Performs a all-gather of k and v over context parallel ranks."""
def ag(x):
return lax_paral_op(
x = lax_paral_op(
x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
)
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True)
return x
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
......@@ -974,6 +1045,10 @@ class _FusedAttnCPWithAllGatherHelper:
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""
def rs(x):
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False)
return lax_paral_op(
x,
lax.psum_scatter,
......@@ -1078,7 +1153,6 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed):
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
......@@ -1120,7 +1194,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
q_seq_offsets,
k_seq_offsets,
seed,
config=config,
config=helper.get_step_config(),
)
results.append((output, softmax_aux, rng_state))
......@@ -1237,7 +1311,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen_for_step,
q_seq_offsets,
k_seq_offsets,
config=config,
config=helper.get_step_config(),
)
# pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment