Unverified commit d3cbccdf, authored by Michael Goldfarb, committed by GitHub
Browse files

[JAX] Scale sequence length in CP tests to avoid tiny sizes. (#1347)



Scale sequence length in CP tests to avoid tiny sizes.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 64126aa8
...@@ -341,8 +341,9 @@ class TestDistributedCrossAttn: ...@@ -341,8 +341,9 @@ class TestDistributedCrossAttn:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"data_shape", "data_shape",
[ [
pytest.param([2, 512, 12, 128], id="2-512-12-128"), # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
], ],
) )
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
...@@ -423,6 +424,12 @@ class TestDistributedContextParallelSelfAttn: ...@@ -423,6 +424,12 @@ class TestDistributedContextParallelSelfAttn:
qkv_format = get_qkv_format(qkv_layout) qkv_format = get_qkv_format(qkv_layout)
batch, seqlen, num_head, hidden = data_shape batch, seqlen, num_head, hidden = data_shape
# Scale the sequence length by 2*CP so it's never too small as we scale up the test.
# 2*CP is used since we split into two CP groups for load balancing.
seqlen = seqlen * cp_size * 2
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head) scaling_factor = 1.0 / np.sqrt(num_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment