Unverified Commit 6e848924 authored by Michael Goldfarb, committed by GitHub
Browse files

[JAX] Consolidate the distributed fused attention test code (#1405)



Consolidate the distributed fused attention tests to use shared input generation and execution logic.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent c2937c5a
......@@ -18,14 +18,22 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
)
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
pytest.param(
4,
(2, 2),
("dp", "tp"),
MeshResource(dp_resource="dp", tp_resource="tp"),
id=f"n4_dp2_tp2",
)
)
return configs
......@@ -33,7 +41,8 @@ def generate_configs():
def generate_context_parallel_configs():
configs = []
mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
axes = ("dp", "cp", "tp")
DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
......@@ -41,13 +50,7 @@ def generate_context_parallel_configs():
ndev = cp * tp * dp
if is_devices_enough(ndev):
configs.append(
pytest.param(
ndev,
(dp, cp, tp),
("dp", "cp", "tp"),
MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
)
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
)
return configs
......
This diff is collapsed.
......@@ -3,10 +3,10 @@
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional
from typing import Tuple, Optional, Dict
import random
import jax
......@@ -19,16 +19,22 @@ from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
fused_attn_thd,
make_swa_mask,
CPStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
......@@ -36,7 +42,8 @@ from transformer_engine.transformer_engine_jax import (
get_cudnn_version,
)
from utils import assert_allclose
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
@pytest.fixture(autouse=True, scope="module")
......@@ -304,6 +311,19 @@ class FusedAttnRunner:
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
mesh_shape: tuple[int, ...] = (1, 1, 1)
mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))
# Context parallel aux arguments
cp_strategy: CPStrategy = CPStrategy.DEFAULT
cp_load_balanced: bool = True
# dictionary of expected collective comm bytes
coll_count_ref: Optional[Dict[str, int]] = None
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
......@@ -362,6 +382,14 @@ class FusedAttnRunner:
def _setup_inputs(self):
self._check_configs()
# Create a mesh for distributed tests
self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
self.mesh = Mesh(self.devices, self.mesh_axes)
self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
......@@ -527,6 +555,66 @@ class FusedAttnRunner:
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
# Setup distributed sharding specs
# Setup shardings for distributed tests
self.qkvo_psec = PartitionSpec(
self.mesh_resource.dp_resource,
self.mesh_resource.cp_resource,
self.mesh_resource.tp_resource,
None,
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
self.mask_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._B1SS:
self.bias_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
elif self.bias_shape == BiasShape._11SS:
self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
else:
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
self.dropout_rng_pspec = PartitionSpec(
None,
)
self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)
self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)
# [batch][max_segments_per_batch]
# TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
# Softmax aux sharding
if self.cp_size > 1 and self.cp_load_balanced:
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
def test_forward(self):
"""
Test forward without JIT
......@@ -534,17 +622,21 @@ class FusedAttnRunner:
self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
self.q,
self.k,
self.v,
self.bias,
self.mask_for_customcall,
self.seqlens_q,
self.seqlens_kv,
self.offsets_q,
self.offsets_kv,
self.dropout_rng,
# Put test data onto each GPU for distributed.
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
......@@ -555,10 +647,31 @@ class FusedAttnRunner:
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
}
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
customcall_fused_dpa_jit = jit(
partial(customcall_fused_dpa, **kwargs),
static_argnames=kwargs.keys(),
in_shardings=[
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.dropout_rng_sharding,
],
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)
reference_out = jax_dpa(*args, **kwargs)
if self.is_training and self.dropout_prob > 0.0:
......@@ -571,9 +684,19 @@ class FusedAttnRunner:
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
assert_equal_collectives(target_hlo, self.coll_count_ref)
def test_backward(self):
"""
Test value_and_grad with JIT, which includes both forward and backward
Test value_and_grad with JIT, which includes both forward and backward.
If coll_count_ref is not None then the HLO of the backwards function
will be examined for the expected comms.
"""
self._setup_inputs()
......@@ -587,20 +710,24 @@ class FusedAttnRunner:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
self.q,
self.k,
self.v,
self.bias,
self.mask_for_customcall,
self.seqlens_q,
self.seqlens_kv,
self.offsets_q,
self.offsets_kv,
self.dropout_rng,
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
......@@ -611,10 +738,22 @@ class FusedAttnRunner:
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
}
# We can compute dBias only for the [1, h, s, s] layout
arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)
if self.bias_shape == BiasShape._1HSS:
arg_nums = (0, 1, 2, 3)
grad_shardings = (
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
)
else:
arg_nums = (0, 1, 2)
grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit(
......@@ -623,7 +762,20 @@ class FusedAttnRunner:
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
),
arg_nums,
)
),
in_shardings=(
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.dropout_rng_sharding,
),
out_shardings=(None, grad_shardings),
)
jitted_reference = jit(
value_and_grad(
......@@ -632,20 +784,31 @@ class FusedAttnRunner:
)
)
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
reference_out, reference_dgrad = jitted_reference(*args)
# Skip elementwise comparison when dropout enabled
if self.dropout_prob > 0.0:
return
print_debug_tensor_stats(f"primitive_out", primitive_out)
print_debug_tensor_stats(f"reference_grad_valid", reference_out)
print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
def check_dqkv(primitive, reference, pad):
def check_dqkv(primitive, reference, pad, idx):
primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
_split_valid_and_invalid(primitive, reference, pad)
)
print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
print_debug_tensor_stats(
f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
)
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
......@@ -653,11 +816,17 @@ class FusedAttnRunner:
primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
reference_dq, reference_dk, reference_dv = reference_dgrad[:3]
check_dqkv(primitive_dq, reference_dq, self.pad_q)
check_dqkv(primitive_dk, reference_dk, self.pad_kv)
check_dqkv(primitive_dv, reference_dv, self.pad_kv)
primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)
check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
# TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
primitive_dbias = primitive_dgrad[3]
reference_dbias = reference_dgrad[3]
......@@ -685,6 +854,11 @@ class FusedAttnRunner:
dtype=self.dtype,
)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)
@pytest.mark.parametrize(
"attn_mask_type",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment