Commit 95e46992 authored by lijian6
Browse files

Fix intranode test error.


Signed-off-by: lijian <lijian6@sugon.com>
parent abdd8b40
...@@ -47,7 +47,7 @@ struct Config { ...@@ -47,7 +47,7 @@ struct Config {
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0); EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % (2 * NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL) == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL; const int num_channels = num_ranks <=8 ? num_sms / 2 : num_sms / NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL;
// 计算每个nvl通信数据包的数据量 // 计算每个nvl通信数据包的数据量
size_t num_single_nvl_bag_bytes = size_t num_single_nvl_bag_bytes =
......
...@@ -231,6 +231,10 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks ...@@ -231,6 +231,10 @@ def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks
print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True) print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', flush=True)
print('', flush=True) print('', flush=True)
def get_hidden_bytes(args: argparse.Namespace) -> int:
    """Return the number of bytes in one bf16 token row of width ``args.hidden``.

    The result is ``args.hidden`` multiplied by the bfloat16 element size,
    with the element size clamped to at least 2 bytes. It is used as the
    ``hidden_bytes`` argument when asking the dispatch/combine configs for an
    NVL buffer size hint.

    Args:
        args: Parsed CLI arguments. Only ``args.hidden`` is read;
            ``args.num_tokens`` does not influence the result.

    Returns:
        The per-token row size in bytes.
    """
    # Derive the element size from the dtype itself instead of allocating a
    # full (num_tokens, hidden) tensor only to inspect its shape.
    element_bytes = torch.finfo(torch.bfloat16).bits // 8
    return args.hidden * max(element_bytes, 2)
# noinspection PyUnboundLocalVariable,PyShadowingNames # noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
...@@ -240,11 +244,19 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -240,11 +244,19 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, num_sms = 60
deep_ep.Buffer.set_num_sms(num_sms)
hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes = 0
for config in (deep_ep.Buffer.get_dispatch_config(group.size()), deep_ep.Buffer.get_combine_config(group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True) num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), explicitly_destroy=True)
torch.manual_seed(rank) torch.manual_seed(rank)
for i in (60, ): for i in (num_sms, ):
test_main(args, i, local_rank, num_ranks, rank, buffer, group) test_main(args, i, local_rank, num_ranks, rank, buffer, group)
if local_rank == 0: if local_rank == 0:
print('', flush=True) print('', flush=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment