Unverified commit 6d178b4e authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Reduce L1 tests/jax/test_distributed_softmax.py test runtime (#2031)



* Pytest timings
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Reduce softmax test shape sizes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Switch softmax tests to use shardy by default
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent ed42b5ac
......@@ -5,6 +5,8 @@
import os
import jax
import pytest
from collections import defaultdict
import time
import transformer_engine.jax
......@@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper():
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
class TestTimingPlugin:
    """
    Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1
    in the environment.
    """

    def __init__(self):
        # Maps base test function name -> list of per-parametrization durations (seconds).
        self.test_timings = defaultdict(list)

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_setup(self, item):
        """Record the start time on the item just before it runs."""
        # Use a monotonic clock: time.time() is wall-clock and can jump
        # (NTP adjustment, DST), yielding skewed or negative durations.
        item._timing_start = time.perf_counter()

    @pytest.hookimpl(trylast=True)
    def pytest_runtest_teardown(self, item, nextitem):
        """Accumulate the elapsed time under the test's base function name."""
        if hasattr(item, "_timing_start"):
            duration = time.perf_counter() - item._timing_start
            # Extract base function name without the parametrization suffix;
            # split() returns the whole name when there is no "[".
            base_name = item.name.split("[")[0]
            self.test_timings[base_name].append(duration)

    def pytest_sessionfinish(self, session, exitstatus):
        """Print a runtime summary grouped by test function at session end."""
        if not self.test_timings:
            # Nothing was timed (e.g. collection error) - skip the empty table.
            return
        print("\n" + "=" * 80)
        print("TEST RUNTIME SUMMARY (grouped by function)")
        print("=" * 80)
        total_overall = 0
        for test_name, durations in sorted(self.test_timings.items()):
            total_time = sum(durations)
            count = len(durations)
            avg_time = total_time / count if count > 0 else 0
            total_overall += total_time
            print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s")
        print("=" * 80)
        print(f"{'TOTAL RUNTIME':<60} | {'':>3} | {total_overall:7.2f}s |")
        print("=" * 80)
def pytest_configure(config):
    """Register the test-timing plugin when NVTE_JAX_TEST_TIMING=1 is set."""
    timing_enabled = os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1"
    if not timing_enabled:
        return
    config.pluginmanager.register(TestTimingPlugin(), "test_timing")
......@@ -135,7 +135,7 @@ class TestDistributedSoftmax:
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
......@@ -168,14 +168,14 @@ class TestDistributedSoftmax:
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy=False,
use_shardy=True,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_shardy(
def test_softmax_gspmd(
self,
device_count,
mesh_shape,
......@@ -196,5 +196,5 @@ class TestDistributedSoftmax:
dtype=DTYPES[0],
bad_sharding=bad_sharding,
broadcast_batch_mask=broadcast_batch_mask,
use_shardy=True,
use_shardy=False,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment