Commit c5b40405 authored by Chenggang Zhao's avatar Chenggang Zhao
Browse files

Enable intranode kernel tests with EP2 and EP4

parent 6cc3497d
...@@ -11,10 +11,11 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to ...@@ -11,10 +11,11 @@ from utils import init_dist, bench, calc_diff, inplace_unique, per_token_cast_to
import test_low_latency import test_low_latency
def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup):
# Settings # Settings
# TODO: fix EP2/4/8 performance
num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
assert num_experts % num_ranks == 0 and num_local_ranks == 8 assert num_experts % num_ranks == 0
if local_rank == 0: if local_rank == 0:
print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True) print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
...@@ -208,7 +209,7 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -208,7 +209,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (24, ): for i in (24, ):
test_main(i, local_rank, num_local_ranks, num_ranks, rank, buffer, group) test_main(i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0: if local_rank == 0:
print() print()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment