Unverified commit 92f431bf authored by Kshitij Lakhani, committed by GitHub
Browse files

[JAX] Trim dist fused attn tests in L1 (#2050)



* Move some dist fused attn tests to L2
1. TestReorderCausalLoadBalancing: Run two (non-symmetric) BSHD/SBHD data shape combinations
2. TestDistributedSelfAttn: Run only one (smaller) BSHD type data shape combination
3. TestDistributedCrossAttn: Run only one (smaller) BSHD type data shape combination
4. TestDistributedContextParallelSelfAttn: Run all cp1 combinations
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Use pytest_parametrize_wrapper for splitting fused attn distributed JAX tests as L1 and L2
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Undo pytest -k split commands in qa scripts
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix usage of pytest_parametrize_wrapper in test_distributed_fused_attn
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove test code for L2 dist residing in L2 test.sh
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Add comments for code. Swap the test data shapes in REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add L0 to the data shape dictionaries in the distributed test
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
parent 12065ac2
...@@ -39,8 +39,10 @@ def generate_configs(): ...@@ -39,8 +39,10 @@ def generate_configs():
return configs return configs
def generate_context_parallel_configs(): def generate_context_parallel_configs_for_attn():
configs = [] """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only"""
configsL1 = []
configsL2 = []
mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
axes = ("dp", "cp", "tp") axes = ("dp", "cp", "tp")
DP_sizes = (1, 2) DP_sizes = (1, 2)
...@@ -49,10 +51,16 @@ def generate_context_parallel_configs(): ...@@ -49,10 +51,16 @@ def generate_context_parallel_configs():
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp ndev = cp * tp * dp
if is_devices_enough(ndev): if is_devices_enough(ndev):
configs.append( # Do not run cp1 case in L1 as that is already covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations)
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") if cp != 1:
) configsL1.append(
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
)
else:
configsL2.append(
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
)
configs = {"L0": [], "L1": configsL1, "L2": configsL2}
return configs return configs
......
...@@ -9,10 +9,11 @@ import jax.numpy as jnp ...@@ -9,10 +9,11 @@ import jax.numpy as jnp
from jax import random from jax import random
from distributed_test_base import ( from distributed_test_base import (
generate_configs, generate_configs,
generate_context_parallel_configs, generate_context_parallel_configs_for_attn,
generate_collectives_count, generate_collectives_count,
) )
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from utils import pytest_parametrize_wrapper
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
is_fused_attn_kernel_available, is_fused_attn_kernel_available,
AttnBiasType, AttnBiasType,
...@@ -28,6 +29,12 @@ from transformer_engine.jax.attention import ( ...@@ -28,6 +29,12 @@ from transformer_engine.jax.attention import (
DTYPES = [jnp.bfloat16] DTYPES = [jnp.bfloat16]
DISTRIBUTED_SELF_ATTN_DATA_SHAPES = {
"L0": [()],
"L1": [(32, 1024, 16, 128)],
"L2": [(32, 512, 12, 64)],
}
class TestDistributedSelfAttn: class TestDistributedSelfAttn:
...@@ -64,7 +71,6 @@ class TestDistributedSelfAttn: ...@@ -64,7 +71,6 @@ class TestDistributedSelfAttn:
jax.config.update("jax_use_shardy_partitioner", use_shardy) jax.config.update("jax_use_shardy_partitioner", use_shardy)
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available( if not is_fused_attn_kernel_available(
...@@ -119,13 +125,7 @@ class TestDistributedSelfAttn: ...@@ -119,13 +125,7 @@ class TestDistributedSelfAttn:
runner.test_backward() runner.test_backward()
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize( @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES)
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_bias_type, bias_shape", "attn_bias_type, bias_shape",
[ [
...@@ -193,6 +193,13 @@ class TestDistributedSelfAttn: ...@@ -193,6 +193,13 @@ class TestDistributedSelfAttn:
) )
DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = {
"L0": [()],
"L1": [[32, 512, 16, 64]],
"L2": [[32, 128, 12, 64]],
}
class TestDistributedCrossAttn: class TestDistributedCrossAttn:
def generate_collectives_count_ref(self): def generate_collectives_count_ref(self):
...@@ -201,7 +208,7 @@ class TestDistributedCrossAttn: ...@@ -201,7 +208,7 @@ class TestDistributedCrossAttn:
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
) )
...@@ -390,8 +397,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -390,8 +397,9 @@ class TestDistributedContextParallelSelfAttn:
runner.test_backward() runner.test_backward()
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
...@@ -426,8 +434,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -426,8 +434,9 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=True, use_shardy=True,
) )
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8]) @pytest.mark.parametrize("kv_groups", [1, 8])
...@@ -468,8 +477,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -468,8 +477,9 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=False, use_shardy=False,
) )
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("kv_groups", [1, 8]) @pytest.mark.parametrize("kv_groups", [1, 8])
...@@ -532,8 +542,9 @@ class TestDistributedContextParallelSelfAttn: ...@@ -532,8 +542,9 @@ class TestDistributedContextParallelSelfAttn:
window_size=window_size, window_size=window_size,
) )
@pytest.mark.parametrize( @pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() "device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
) )
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
...@@ -570,16 +581,16 @@ class TestDistributedContextParallelSelfAttn: ...@@ -570,16 +581,16 @@ class TestDistributedContextParallelSelfAttn:
) )
REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
"L0": [[]],
"L1": [[3, 32, 8, 64]],
"L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}
class TestReorderCausalLoadBalancing: class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8]) @pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest.mark.parametrize( @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
"shape",
[
pytest.param([1, 16, 1, 1], id="1-16-1-1"),
pytest.param([4, 32, 12, 32], id="4-32-12-32"),
pytest.param([3, 32, 8, 64], id="3-32-8-64"),
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"reorder_strategy", "reorder_strategy",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment