Unverified commit 6e848924, authored by Michael Goldfarb, committed by GitHub
Browse files

[JAX] Consolidate the distributed fused attention test code (#1405)



Consolidate the distributed fused attention tests to use shared input generation and execution logic.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent c2937c5a
......@@ -18,14 +18,22 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
)
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
pytest.param(
4,
(2, 2),
("dp", "tp"),
MeshResource(dp_resource="dp", tp_resource="tp"),
id=f"n4_dp2_tp2",
)
)
return configs
......@@ -33,7 +41,8 @@ def generate_configs():
def generate_context_parallel_configs():
configs = []
mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
axes = ("dp", "cp", "tp")
DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
......@@ -41,13 +50,7 @@ def generate_context_parallel_configs():
ndev = cp * tp * dp
if is_devices_enough(ndev):
configs.append(
pytest.param(
ndev,
(dp, cp, tp),
("dp", "cp", "tp"),
MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
)
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
)
return configs
......
......@@ -37,8 +37,7 @@ from transformer_engine.jax.attention import (
)
from transformer_engine.jax.sharding import MeshResource
# We will use the golden reference model from our non-distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -49,7 +48,7 @@ class TestDistributedSelfAttn:
self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
_, seqlen, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
tp_size = 1
if mesh_resource.tp_resource is not None:
......@@ -62,45 +61,28 @@ class TestDistributedSelfAttn:
# for loss and dbias
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype):
batch, seqlen, _, heads, _ = shape
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = (
random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
if with_bias
else None
)
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
mask = make_causal_mask(batch, seqlen)
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
bias_pspec = (
PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
......@@ -111,14 +93,14 @@ class TestDistributedSelfAttn:
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
):
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape
batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(
dtype,
......@@ -136,74 +118,36 @@ class TestDistributedSelfAttn:
):
pytest.skip(f"No FusedAttn backend found")
def target_func(qkv, bias, mask):
return jnp.mean(
fused_attn(
(qkv,),
bias,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=QKVLayout.BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
)
def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, with_bias, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = (
jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
)
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
col_ref = self.generate_collectives_count_ref(
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type != AttnBiasType.NO_BIAS,
data_shape,
dtype,
)
grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(
target_func,
ref_func,
[qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings),
runner = FusedAttnRunner(
batch,
seqlen,
seqlen,
num_head,
num_head,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
QKVLayout.BS3HD,
bias_shape,
None,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
coll_count_ref=col_ref,
)
runner.test_backward()
class TestDistributedCrossAttn:
......@@ -213,31 +157,6 @@ class TestDistributedCrossAttn:
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
batch, seqlen, heads, hidden = shape
q = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype)
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
mask = make_causal_mask(batch, seqlen)
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize(
......@@ -248,11 +167,11 @@ class TestDistributedCrossAttn:
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, num_head, hidden = data_shape
batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(
dtype,
......@@ -270,67 +189,29 @@ class TestDistributedCrossAttn:
):
pytest.skip(f"No FusedAttn backend found")
def target_func(q, kv, mask):
return jnp.mean(
fused_attn(
(q, kv),
None,
mask,
col_ref = self.generate_collectives_count_ref()
runner = FusedAttnRunner(
batch,
seqlen,
seqlen,
num_head,
num_head,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
QKVLayout.BSHD_BS2HD,
bias_shape,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
),
dtype=jnp.float32,
)
def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
query = jnp.squeeze(query)
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(
query,
key,
value,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output, dtype=jnp.float32)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
compare_ops(
target_func,
ref_func,
[q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
coll_count_ref=col_ref,
)
runner.test_backward()
@pytest.mark.parametrize(
......@@ -366,41 +247,6 @@ class TestDistributedCrossAttn:
)
class TestDistributedContextParallelSelfAttn:
def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
batch, seqlen, heads, hidden = shape
kv_shape = (batch, seqlen, heads // kv_groups, hidden)
qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3)
q = random.normal(qkey, shape, dtype=dtype)
k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype)
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return tokens, jnp.logical_not(tokens)
from test_fused_attn import make_mask
q_idx, _ = gen_valid(batch, seqlen, 0.0)
kv_idx, _ = gen_valid(batch, seqlen, 0.0)
mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type)
return q, k, v, mask
def qkv_to_layout(self, q, k, v, qkv_layout):
qkv_args = ()
match qkv_layout:
case QKVLayout.BSHD_BS2HD:
k, v = map(partial(jnp.expand_dims, axis=-3), [k, v])
kv = jnp.concatenate((k, v), axis=-3)
qkv_args = (q, kv)
case QKVLayout.BSHD_BSHD_BSHD:
qkv_args = (q, k, v)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args
def impl_test_context_parallel_attn(
self,
device_count,
......@@ -416,6 +262,7 @@ class TestDistributedContextParallelSelfAttn:
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
......@@ -431,6 +278,29 @@ class TestDistributedContextParallelSelfAttn:
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
runner = FusedAttnRunner(
batch,
seqlen,
seqlen,
num_head,
num_kv_heads,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
None,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
cp_strategy=cp_strategy,
cp_load_balanced=load_balanced,
)
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
dtype,
......@@ -465,123 +335,7 @@ class TestDistributedContextParallelSelfAttn:
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")
def target_func(q, k, v, mask):
return fused_attn(
self.qkv_to_layout(q, k, v, qkv_layout),
None, # bias
mask,
None, # seed
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
def ref_func(q, k, v, mask):
output = general_dot_product_attention(
q,
k,
v,
bias=None,
mask=mask,
deterministic=not is_training,
scale_factor=scaling_factor,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return output.astype(dtype)
def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
_, max_seq_len, num_heads, _ = data_shape
gradient_multiplier = max_seq_len * num_heads
if attn_mask_type.is_causal():
gradient_multiplier /= 10
ret_valid = func(*args, **kwargs)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype)
diff_argnums = (0, 1, 2)
# Single GPU (reference)
ref_func_jit = jax.jit(
jax.value_and_grad(
lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums
)
)
ref_fwd, ref_grads = ref_func_jit(q, k, v, mask)
# Multi GPU (function under test)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False):
qkv_ps = PartitionSpec(
mesh_resource.dp_resource,
mesh_resource.cp_resource,
mesh_resource.tp_resource,
None,
)
qkv_sharding = NamedSharding(mesh, qkv_ps)
mask_ps = PartitionSpec(
mesh_resource.dp_resource, None, mesh_resource.cp_resource, None
)
mask_sharding = NamedSharding(mesh, mask_ps)
reorder = partial(
reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
inverse_reorder = partial(
inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format
)
if load_balanced:
q, k, v = jax.tree.map(reorder, (q, k, v))
q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v])
mask_ = jax.device_put(mask, device=mask_sharding)
target_func_jit = jax.jit(
jax.value_and_grad(
lambda q, k, v, mask: grad_func(target_func, q, k, v, mask),
argnums=diff_argnums,
),
in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding],
out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)),
)
target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_)
if load_balanced:
target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3])
target_grads = (target_dq, target_dk, target_dv, *target_grads[3:])
has_diffs = False
print_debug_tensor_stats("target", target_fwd)
print_debug_tensor_stats("ref", ref_fwd)
print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd))
assert_allclose(target_fwd, ref_fwd, dtype=dtype)
for i in range(len(target_grads)):
if ref_grads[i] is None or target_grads[i] is None:
# expect both none if one is
assert target_grads[i] is None and ref_grads[i] is None
else:
print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i])
print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i])
print_debug_tensor_stats(
f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i])
)
assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
runner.test_backward()
def test_context_parallel_allgather_attn(
self,
......
......@@ -3,10 +3,10 @@
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional
from typing import Tuple, Optional, Dict
import random
import jax
......@@ -19,16 +19,22 @@ from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
fused_attn_thd,
make_swa_mask,
CPStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
......@@ -36,7 +42,8 @@ from transformer_engine.transformer_engine_jax import (
get_cudnn_version,
)
from utils import assert_allclose
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
@pytest.fixture(autouse=True, scope="module")
......@@ -304,6 +311,19 @@ class FusedAttnRunner:
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
mesh_shape: tuple[int, ...] = (1, 1, 1)
mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))
# Context parallel aux arguments
cp_strategy: CPStrategy = CPStrategy.DEFAULT
cp_load_balanced: bool = True
# dictionary of expected collective comm bytes
coll_count_ref: Optional[Dict[str, int]] = None
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
......@@ -362,6 +382,14 @@ class FusedAttnRunner:
def _setup_inputs(self):
self._check_configs()
# Create a mesh for distributed tests
self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
self.mesh = Mesh(self.devices, self.mesh_axes)
self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
......@@ -527,6 +555,66 @@ class FusedAttnRunner:
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
# Setup distributed sharding specs
# Setup shardings for distributed tests
self.qkvo_psec = PartitionSpec(
self.mesh_resource.dp_resource,
self.mesh_resource.cp_resource,
self.mesh_resource.tp_resource,
None,
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
self.mask_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._B1SS:
self.bias_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._11SS:
self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
else:
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
self.dropout_rng_pspec = PartitionSpec(
None,
)
self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)
self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)
# [batch][max_segments_per_batch]
# TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
# Softmax aux sharding
if self.cp_size > 1 and self.cp_load_balanced:
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
def test_forward(self):
"""
Test forward without JIT
......@@ -534,17 +622,21 @@ class FusedAttnRunner:
self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
self.q,
self.k,
self.v,
self.bias,
self.mask_for_customcall,
self.seqlens_q,
self.seqlens_kv,
self.offsets_q,
self.offsets_kv,
self.dropout_rng,
# Put test data onto each GPU for distributed.
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
......@@ -555,10 +647,31 @@ class FusedAttnRunner:
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
}
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
customcall_fused_dpa_jit = jit(
partial(customcall_fused_dpa, **kwargs),
static_argnames=kwargs.keys(),
in_shardings=[
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.dropout_rng_sharding,
],
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)
reference_out = jax_dpa(*args, **kwargs)
if self.is_training and self.dropout_prob > 0.0:
......@@ -571,9 +684,19 @@ class FusedAttnRunner:
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
assert_equal_collectives(target_hlo, self.coll_count_ref)
def test_backward(self):
"""
Test value_and_grad with JIT, which includes both forward and backward
Test value_and_grad with JIT, which includes both forward and backward.
If coll_count_ref is not None then the HLO of the backward function
will be examined for the expected collective communications.
"""
self._setup_inputs()
......@@ -587,20 +710,24 @@ class FusedAttnRunner:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
self.q,
self.k,
self.v,
self.bias,
self.mask_for_customcall,
self.seqlens_q,
self.seqlens_kv,
self.offsets_q,
self.offsets_kv,
self.dropout_rng,
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
......@@ -611,10 +738,22 @@ class FusedAttnRunner:
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
}
# We can compute dBias only for the [1, h, s, s] layout
arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)
if self.bias_shape == BiasShape._1HSS:
arg_nums = (0, 1, 2, 3)
grad_shardings = (
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
)
else:
arg_nums = (0, 1, 2)
grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)
# Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
jitted_primitive = jit(
......@@ -623,7 +762,20 @@ class FusedAttnRunner:
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
),
arg_nums,
)
),
in_shardings=(
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.dropout_rng_sharding,
),
out_shardings=(None, grad_shardings),
)
jitted_reference = jit(
value_and_grad(
......@@ -632,20 +784,31 @@ class FusedAttnRunner:
)
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
reference_out, reference_dgrad = jitted_reference(*args)
# Skip elementwise comparison when dropout enabled
if self.dropout_prob > 0.0:
return
print_debug_tensor_stats(f"primitive_out", primitive_out)
print_debug_tensor_stats(f"reference_grad_valid", reference_out)
print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
def check_dqkv(primitive, reference, pad):
def check_dqkv(primitive, reference, pad, idx):
primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
_split_valid_and_invalid(primitive, reference, pad)
)
print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
print_debug_tensor_stats(
f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
)
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
......@@ -653,11 +816,17 @@ class FusedAttnRunner:
primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
reference_dq, reference_dk, reference_dv = reference_dgrad[:3]
check_dqkv(primitive_dq, reference_dq, self.pad_q)
check_dqkv(primitive_dk, reference_dk, self.pad_kv)
check_dqkv(primitive_dv, reference_dv, self.pad_kv)
primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)
check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
# TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
primitive_dbias = primitive_dgrad[3]
reference_dbias = reference_dgrad[3]
......@@ -685,6 +854,11 @@ class FusedAttnRunner:
dtype=self.dtype,
)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)
@pytest.mark.parametrize(
"attn_mask_type",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment