"vscode:/vscode.git/clone" did not exist on "5b3d65cc1a157ac76d0e4c6342db0c5d80f69984"
Unverified Commit d3cbccdf authored by Michael Goldfarb's avatar Michael Goldfarb Committed by GitHub
Browse files

[JAX] Scale sequence length in CP tests to avoid tiny sizes. (#1347)



Scale sequence length in CP tests to avoid tiny sizes.
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
parent 64126aa8
......@@ -341,8 +341,9 @@ class TestDistributedCrossAttn:
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
......@@ -423,6 +424,12 @@ class TestDistributedContextParallelSelfAttn:
qkv_format = get_qkv_format(qkv_layout)
batch, seqlen, num_head, hidden = data_shape
# Scale the sequence length by 2*CP so its never too small as we scale up test.
# 2*CP is used since we split into two CP groups for load balancing.
seqlen = seqlen * cp_size * 2
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment