Unverified commit 8dcdd349, authored by fzyzcjy, committed by GitHub
Browse files

cherry pick (#251)

parent 19fc0700
...@@ -13,7 +13,12 @@ import test_low_latency ...@@ -13,7 +13,12 @@ import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096"))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168"))
num_topk_groups = int(os.environ.get("DEEPEP_TEST_NUM_TOPK_GROUPS", str(min(num_nodes, 4))))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8"))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks)))
assert num_experts % num_ranks == 0 and num_local_ranks == 8 assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0: if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True)
......
...@@ -13,7 +13,11 @@ import test_low_latency ...@@ -13,7 +13,11 @@ import test_low_latency
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096"))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168"))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8"))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks)))
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
if local_rank == 0: if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True) print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
......
...@@ -160,7 +160,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -160,7 +160,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int): def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288 num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "128"))
hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168"))
num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8"))
num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", "288"))
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0: if local_rank == 0:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment