Unverified commit 5f0e3b93 authored by Kshitij Lakhani, committed by GitHub

[JAX] Refactor and trim TE JAX Attn testing (#2542)



* Pick a leaner set of combinations for the TE JAX CP attn tests: only those cp,dp,tp combinations where cp*dp*tp exactly equals the number of available GPUs are run (sketched below, after the commit trailers)
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Consolidate the test cases run for different B, S, H, D, and QKV layouts
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code and comments clean up
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make the FP16 + GQA test use cross attn instead of self attn to generalize the test
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5f828c25
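The first bullet above describes the new selection rule for the context-parallel configurations: a dp,cp,tp combination runs only when cp * dp * tp exactly equals the number of visible GPUs. The following is a minimal sketch of that rule written against plain jax.devices(); the pick_cp_configs name and the DP/CP size tuples are illustrative assumptions, as only TP_sizes = (1, 2) is visible in the diff below.

# Minimal sketch; the helper name and the DP/CP size tuples are assumptions.
from itertools import product

import jax


def pick_cp_configs(dp_sizes=(1, 2), cp_sizes=(2, 4), tp_sizes=(1, 2)):
    """Keep only (dp, cp, tp) combinations whose product matches the visible GPU count."""
    num_gpus = len(jax.devices())
    picked = []
    for dp, cp, tp in product(dp_sizes, cp_sizes, tp_sizes):
        ndev = cp * tp * dp
        # Exact match, mirroring is_devices_equal(ndev) in the diff below:
        # with 8 visible GPUs only ndev == 8 combinations run; setting
        # CUDA_VISIBLE_DEVICES=0,1,2,3 shifts the match to ndev == 4.
        if ndev == num_gpus and cp != 1:  # cp == 1 is covered by the non-CP attn tests
            picked.append((dp, cp, tp))
    return picked


print(pick_cp_configs())  # [(1, 4, 2), (2, 2, 2), (2, 4, 1)] on an 8-GPU node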
@@ -12,7 +12,7 @@ from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
from transformer_engine.jax.sharding import MeshResource
from utils import assert_allclose, is_devices_enough
from utils import assert_allclose, is_devices_enough, is_devices_equal
def generate_configs():
@@ -49,7 +49,11 @@ def generate_context_parallel_configs_for_attn():
TP_sizes = (1, 2)
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp
if is_devices_enough(ndev):
# Run only those dp,cp,tp combinations whose required device count (ndev = cp * tp * dp) exactly matches the number of available GPUs.
# For example, with 8 GPUs, the dp,cp,tp combinations with ndev == 8 are picked, while those with ndev == 4 are skipped.
# To explicitly run the ndev == 4 combinations, set CUDA_VISIBLE_DEVICES=0,1,2,3 so that the number of visible GPUs is 4 instead of 8.
if is_devices_equal(ndev):
# Do not run the cp=1 case at the L1 level, as it is already covered by TestDistributedSelfAttn and TestDistributedCrossAttn (which have no cp parametrization)
if cp != 1:
configsL1.append(
......
@@ -334,7 +334,7 @@ DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
class TestDistributedContextParallelSelfAttn:
# TODO(KshitijLakhani): parametrize num_segments_per_seq for all CP tests
def impl_test_context_parallel_attn(
self,
device_count,
......
@@ -1068,41 +1068,70 @@ class FusedAttnRunner:
],
)
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + bf16 + qkv packed
pytest.param(
2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
2,
2048,
2048,
12,
12,
64,
64,
jnp.bfloat16,
QKVLayout.BS3HD,
id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED",
),
pytest.param(
2,
512,
1024,
2048,
2048,
12,
12,
64,
64,
jnp.bfloat16,
id="2-512-1024-12-12-64-64-BF16-CROSS",
QKVLayout.T3HD,
id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED",
),
# mid data size + bf16 + cross attn + kv packed
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
2,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
QKVLayout.BSHD_BS2HD,
id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED",
),
pytest.param(
4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
2,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
QKVLayout.THD_T2HD,
id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED",
),
# large data size + bf16 + cross attn + diff hidden v dim + qkv separate
pytest.param(
4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
2,
2048,
1024,
12,
12,
64,
32,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE",
),
pytest.param(
2,
@@ -1113,10 +1142,108 @@ class FusedAttnRunner:
64,
32,
jnp.bfloat16,
id="2-2048-1024-12-12-64-32-BF16-CROSS",
QKVLayout.THD_THD_THD,
id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE",
),
# large data size + bf16 + gqa + kv packed
pytest.param(
2,
2048,
2048,
12,
6,
64,
64,
jnp.bfloat16,
QKVLayout.BSHD_BS2HD,
id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED",
),
pytest.param(
2,
2048,
2048,
12,
6,
64,
64,
jnp.bfloat16,
QKVLayout.THD_T2HD,
id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED",
),
# small data size + fp16 + diff hidden v dim + qkv packed
pytest.param(
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
4,
128,
128,
16,
16,
64,
32,
jnp.float16,
QKVLayout.BS3HD,
id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED",
),
pytest.param(
4,
128,
128,
16,
16,
64,
32,
jnp.float16,
QKVLayout.T3HD,
id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED",
),
# small data size + fp16 + kv packed
pytest.param(
4,
128,
128,
16,
16,
64,
64,
jnp.float16,
QKVLayout.BSHD_BS2HD,
id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED",
),
pytest.param(
4,
128,
128,
16,
16,
64,
64,
jnp.float16,
QKVLayout.THD_T2HD,
id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED",
),
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.float16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.float16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
......
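The hunk above drops the standalone qkv_layout parametrize (its removal is visible at the top of the hunk) and folds the layout into the shape/dtype parametrize. The reason this trims the matrix is that stacked pytest parametrize decorators multiply into a Cartesian product, whereas one combined decorator runs exactly the listed cases. A minimal sketch with made-up shape values and layout names, not the actual TE test file:

import pytest


# Before: two stacked decorators -> 2 shapes x 2 layouts = 4 test instances.
@pytest.mark.parametrize("qkv_layout", ["QKV_PACKED", "KV_PACKED"])
@pytest.mark.parametrize("b, s", [(2, 2048), (4, 128)])
def test_cross_product(b, s, qkv_layout):
    assert b > 0 and s > 0 and qkv_layout


# After: one combined decorator -> only the two hand-picked (shape, layout) pairs run.
@pytest.mark.parametrize(
    "b, s, qkv_layout",
    [
        pytest.param(2, 2048, "QKV_PACKED", id="2-2048-QKV_PACKED"),
        pytest.param(4, 128, "KV_PACKED", id="4-128-KV_PACKED"),
    ],
)
def test_consolidated(b, s, qkv_layout):
    assert b > 0 and s > 0 and qkv_layout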
@@ -47,6 +47,13 @@ def is_devices_enough(required):
return len(jax.devices()) >= required
def is_devices_equal(required):
"""
Check whether the number of available devices (GPUs) is exactly equal to the required count.
"""
return len(jax.devices()) == required
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
drop_path_shape = list(range(0, len(shape)))
......
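The new is_devices_equal helper complements the existing is_devices_enough: it is what the context-parallel config filter uses to demand an exact device count. A short behavioral sketch, assuming an 8-GPU node; the CUDA_VISIBLE_DEVICES workflow comes from the comment in the config-generation hunk, and the values in the trailing comments hold only for that assumed setup.

import jax

required = 4
# On a node with 8 visible GPUs:
print(len(jax.devices()) >= required)  # is_devices_enough(4) -> True
print(len(jax.devices()) == required)  # is_devices_equal(4)  -> False, so ndev == 4 configs are skipped
# Restricting visibility, e.g. CUDA_VISIBLE_DEVICES=0,1,2,3, makes the equality
# check pass for required == 4, which is how the ndev == 4 combinations get run.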