"vscode:/vscode.git/clone" did not exist on "d3b4044e13b45b3a648e4e4ef86edfcfce3baa4e"
Commit da13c63a authored by lishen's avatar lishen
Browse files

完成低延迟接口功能

parent 09cb2b03
../../../lib/cmake/rocshmem/rocshmem-targets-release.cmake #----------------------------------------------------------------
\ No newline at end of file # Generated CMake target import file for configuration "Release".
#----------------------------------------------------------------
# Expose the import-file format version to the commands in this file.
set(CMAKE_IMPORT_FILE_VERSION 1)

# Attach the "Release" configuration to the imported target "roc::rocshmem".
# Neither the target nor `_IMPORT_PREFIX` is created in this file; both are
# only read here, so they must already exist when this file is included.
set_property(TARGET roc::rocshmem APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
set_property(TARGET roc::rocshmem PROPERTY
  IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX")
set_property(TARGET roc::rocshmem PROPERTY
  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/librocshmem.a")

# Record the target and the archive it points at in the shared check lists.
# These lists are only appended to here; nothing in this file reads them.
list(APPEND _cmake_import_check_targets roc::rocshmem)
list(APPEND _cmake_import_check_files_for_roc::rocshmem
  "${_IMPORT_PREFIX}/lib/librocshmem.a")

# Nothing past this point needs the format version.
unset(CMAKE_IMPORT_FILE_VERSION)
../../../lib/cmake/rocshmem/rocshmem-targets.cmake # Generated by CMake
\ No newline at end of file
# Refuse to run under a CMake older than 2.8. This first test compares the
# "major.minor" string with the plain numeric LESS operator.
# NOTE(review): presumably written this way so the check itself still parses
# on very old CMake releases; confirm before changing its form.
if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
message(FATAL_ERROR "CMake >= 2.8.0 required")
endif()
# Finer-grained guard on the full version string: 2.8.12 is the real minimum.
if(CMAKE_VERSION VERSION_LESS "2.8.12")
message(FATAL_ERROR "CMake >= 2.8.12 required")
endif()
# Save the includer's policy settings, then select the policy behaviour this
# file expects (range 2.8.12 through 3.27). Each way out of this file issues a
# matching cmake_policy(POP).
cmake_policy(PUSH)
cmake_policy(VERSION 2.8.12...3.27)
#----------------------------------------------------------------
# Generated CMake target import file.
#----------------------------------------------------------------

# Expose the import-file format version to the commands below.
set(CMAKE_IMPORT_FILE_VERSION 1)

# Guard against repeated inclusion: adding an imported target that already
# exists would fail, so first sort the expected targets into "already defined"
# and "not yet defined".
set(_cmake_expected_targets "")
set(_cmake_targets_not_defined "")
set(_cmake_targets_defined "")
foreach(_cmake_expected_target IN ITEMS roc::rocshmem)
  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
  if(NOT TARGET "${_cmake_expected_target}")
    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
  else()
    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
  endif()
endforeach()
unset(_cmake_expected_target)

# Every expected target already exists: this file was loaded before. Clean up
# the temporaries, restore the includer's policies and stop processing.
if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
  unset(_cmake_expected_targets)
  unset(_cmake_targets_not_defined)
  unset(_cmake_targets_defined)
  unset(CMAKE_IMPORT_FILE_VERSION)
  cmake_policy(POP)
  return()
endif()

# Only part of the export set exists: report which targets are present and
# which are missing, then abort.
if(NOT _cmake_targets_defined STREQUAL "")
  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
endif()

# None of the targets exist yet: drop the bookkeeping lists and carry on.
unset(_cmake_expected_targets)
unset(_cmake_targets_not_defined)
unset(_cmake_targets_defined)
# Derive the installation prefix from this file's own location: begin at the
# directory that holds this file and climb three more levels.
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_DIR}" DIRECTORY)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" DIRECTORY)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" DIRECTORY)

# A prefix of exactly "/" is replaced by the empty string, so paths built as
# "${_IMPORT_PREFIX}/..." do not start with a doubled slash.
if(_IMPORT_PREFIX STREQUAL "/")
  set(_IMPORT_PREFIX "")
endif()
# Create imported target roc::rocshmem
add_library(roc::rocshmem STATIC IMPORTED)

# Usage requirements handed to every consumer that links roc::rocshmem.
#  * Compile options and include directories each list their single entry
#    once; the earlier revision of this block repeated "-fgpu-rdc" and the
#    include directory.
#    NOTE(review): this file is marked as generated, so the repetition most
#    likely originates in the exporting project's CMakeLists; fix it there too.
#  * In the link interface, "\$" is an escaped "$", so the stored value holds a
#    "$<...>" generator expression around MPI::MPI_CXX.
#  * The namespaced link items (IBVerbs::verbs, MPI::MPI_CXX, Threads::Threads,
#    hip::device, hip::host, hsa-runtime64::hsa-runtime64) are not defined in
#    this file; presumably the package configuration file locates them before
#    this file is included. TODO confirm.
set_target_properties(roc::rocshmem PROPERTIES
  INTERFACE_COMPILE_OPTIONS "-fgpu-rdc"
  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include"
  INTERFACE_LINK_LIBRARIES "IBVerbs::verbs;numa;\$<\$<BOOL:ON>:MPI::MPI_CXX>;Threads::Threads;hip::device;hip::host;dl;hsa-runtime64::hsa-runtime64;-fgpu-rdc"
)
# Include every per-configuration file that sits beside this one
# (rocshmem-targets-<something>.cmake). They run while _IMPORT_PREFIX is still
# set.
file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/rocshmem-targets-*.cmake")
foreach(_cmake_config_file IN LISTS _cmake_config_files)
  include("${_cmake_config_file}")
endforeach()
unset(_cmake_config_files)
unset(_cmake_config_file)

# The prefix is a temporary of this file; remove it once the includes are done.
unset(_IMPORT_PREFIX)
# Loop over all imported files and verify that they actually exist
# For each target named in _cmake_import_check_targets, every path stored in
# the matching _cmake_import_check_files_for_<target> list must exist on disk;
# the first missing path aborts configuration with a fatal error.
foreach(_cmake_target IN LISTS _cmake_import_check_targets)
# The file check is skipped only when all three hold: CMake is 3.28 or newer,
# an _cmake_import_check_xcframework_for_<target> variable is defined, and the
# path it holds is an existing directory.
if(CMAKE_VERSION VERSION_LESS "3.28"
OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
if(NOT EXISTS "${_cmake_file}")
# NOTE: the message below is one multi-line string literal; its line breaks
# and leading whitespace are part of the printed text, so do not re-indent it.
message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
\"${_cmake_file}\"
but this file does not exist. Possible reasons include:
* The file was deleted, renamed, or moved to another location.
* An install or uninstall procedure did not complete successfully.
* The installation package was faulty and contained
\"${CMAKE_CURRENT_LIST_FILE}\"
but not all the files it references.
")
endif()
endforeach()
endif()
# Remove the inner loop variable and this target's file list once checked.
unset(_cmake_file)
unset("_cmake_import_check_files_for_${_cmake_target}")
endforeach()
# Remove the outer loop variable and the list of targets to check.
unset(_cmake_target)
unset(_cmake_import_check_targets)
# No imported targets exported from the same project in a different export set
# are required by this file, so there is nothing to check for here.

# End of the version-dependent commands: remove the format version and restore
# the policy settings that were saved when this file started.
unset(CMAKE_IMPORT_FILE_VERSION)
cmake_policy(POP)
...@@ -159,7 +159,6 @@ def test_main(args: argparse.Namespace, num_sms: int, ...@@ -159,7 +159,6 @@ def test_main(args: argparse.Namespace, num_sms: int,
# Test combine # Test combine
# torch.cuda.synchronize() # torch.cuda.synchronize()
# print("lijian test dipatch end and combine start.")
bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
...@@ -264,7 +263,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -264,7 +263,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
if args.test_ll_compatibility: if args.test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_sms = 24 num_sms = 30
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
hidden_bytes = get_hidden_bytes(args) hidden_bytes = get_hidden_bytes(args)
...@@ -274,7 +273,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -274,7 +273,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes, low_latency_mode=args.test_ll_compatibility, buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes, low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True, use_default_stream_as_comm_stream=False) num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
assert num_local_ranks == 8 and num_ranks > 8 assert num_local_ranks == 8 and num_ranks > 8
for seed in range(int(1e9)): for seed in range(int(1e9)):
......
import argparse
import random import random
import time
import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import numpy as np
from functools import partial from functools import partial
from typing import Optional
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0):
use_logfmt: bool = False, seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
assert num_experts % num_ranks == 0 assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit # NOTES: the integers greater than 256 exceeds the BF16 precision limit
rank_offset = 128 rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
for i in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
# topk_idx = topk_idx.to(int64_t)
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions # Randomly mask some positions
for i in range(10): for i in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
...@@ -47,75 +31,67 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -47,75 +31,67 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check dispatch correctness # Check dispatch correctness
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True): for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False, ): num_times += 1
for use_ue8m0 in (False, True) if round_scale else (False, ): for i in range((num_times % 2) + 1):
num_times += 1 packed_recv_x, packed_recv_count, handle, event, hook = \
for i in range((num_times % 2) + 1): buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
packed_recv_x, packed_recv_count, handle, event, hook = \ hook() if return_recv_hook else event.current_stream_wait()
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, # print('run {}/{}, dispatch_use_fp8={}'.format(i + 1, num_times, dispatch_use_fp8))
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, # return
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
hook() if return_recv_hook else event.current_stream_wait() simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x if dispatch_use_fp8 else packed_recv_x.clone()
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ # print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n")
if dispatch_use_fp8 else packed_recv_x.clone() # print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n")
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') # print(f"simulated_gemm_x{simulated_gemm_x.cpu()}")
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
for i in range(num_local_experts if do_check else 0): dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
expert_id = rank * num_local_experts + i for i in range(num_local_experts if do_check else 0):
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] expert_id = rank * num_local_experts + i
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1 # Check expert indices
num_valid_tokens = recv_count.item() int_mask = (2 ** 32) - 1
assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}' num_valid_tokens = recv_count.item()
assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
if num_valid_tokens == 0: # Check received data
continue recv_x = recv_x[:num_valid_tokens]
# Check received data recv_x_amin = recv_x[:, :-128].amin(dim=-1)
if current_x is x: recv_src_info = recv_src_info[:num_valid_tokens]
recv_x = recv_x[:num_valid_tokens] assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
recv_x_amin = recv_x[:, :-128].amin(dim=-1) assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
recv_src_info = recv_src_info[:num_valid_tokens] for j in range(num_ranks):
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if round_scale: assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0
else: if dispatch_use_fp8:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
for j in range(num_ranks): hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() else:
if not round_scale: hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 # Check combine correctness
if dispatch_use_fp8: for zero_copy in (False, True):
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) if zero_copy:
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
else: out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
async_finish=not return_recv_hook,
# Check combine correctness return_recv_hook=return_recv_hook, out=out)
for zero_copy in (False, ) if use_logfmt else (False, True): hook() if return_recv_hook else event.current_stream_wait()
if zero_copy: if do_check:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') assert torch.isnan(combined_x).sum().item() == 0
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, assert diff < 1e-5, f'Error: diff={diff}'
use_logfmt=use_logfmt, hash_value ^= hash_tensor(combined_x)
async_finish=not return_recv_hook, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
...@@ -125,100 +101,70 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -125,100 +101,70 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
hook() hook()
# noinspection PyShadowingNames # noinspection PyShadowingNames
def test_func(return_recv_hook: bool): def test_func(zero_copy: bool, return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, async_finish=False, return_recv_hook=return_recv_hook)
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
use_logfmt=use_logfmt, return_recv_hook=return_recv_hook) zero_copy=zero_copy, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth # Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens): for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item() num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections num_combine_comm_bytes += num_bf16_bytes * num_selections
# Dispatch + combine testing # Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False)) avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False))
print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, '
f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True)
# Separate profiling # Separate profiling
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
group.barrier() group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
suppress_kineto_output=True, num_kernels_per_period=2 if return_recv_hook else 1) suppress_kineto_output=True)
if not return_recv_hook: if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
else: else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True) f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us')
return hash_value return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames # noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): def test_loop(local_rank: int, num_local_ranks: int):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden = args.num_tokens, args.hidden # The default setting of deepEP upstream is below:
num_topk, num_experts = args.num_topk, args.num_experts num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 256
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0: if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks, num_qps_per_rank=num_experts // num_ranks, explicitly_destroy=True)
allow_nvlink_for_low_latency_mode=not args.disable_nvlink, explicitly_destroy=True, test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer,
use_logfmt=args.use_logfmt, seed=1)
do_pressure_test = args.pressure_test do_pressure_test = False
for seed in range(int(1e9) if do_pressure_test else 0): for seed in range(int(1e9) if do_pressure_test else 0):
if local_rank == 0: if local_rank == 0:
print(f'Testing with seed {seed} ...', flush=True) print(f'Testing with seed {seed} ...', flush=True)
ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed)
use_logfmt=args.use_logfmt, seed=seed)
for i in range(20): for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
use_logfmt=args.use_logfmt, seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__': if __name__ == '__main__':
print("main start...")
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512` num_processes = 8
parser = argparse.ArgumentParser(description='Test low-latency EP kernels') torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
parser.add_argument('--num-processes', type=int, default=8,
help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128,
help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=7168,
help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8,
help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=288,
help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true",
help='Allow MNNVL for communication')
parser.add_argument('--disable-nvlink', action='store_true',
help='Whether to disable NVLink for testing')
parser.add_argument('--use-logfmt', action='store_true',
help='Whether to test LogFMT combine')
parser.add_argument("--pressure-test", action='store_true',
help='Whether to do pressure test')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment