Commit fe000141 authored by lijian6's avatar lijian6
Browse files

Merge branch 'update_details' into 'main'

Update details

See merge request dcutoolkit/deeplearing/DeepEP!8
parents f08e5bf1 8a0688f3
...@@ -272,8 +272,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -272,8 +272,9 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual); const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx, internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
...@@ -281,9 +282,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -281,9 +282,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#endif #endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr),
num_bytes_per_msg, dst_rank); num_bytes_per_msg, dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} else { } else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr); const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
...@@ -349,19 +348,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -349,19 +348,19 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#if defined(FORCE_NVSHMEM_API) #if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank); void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { // P2P enabled if (peer_base_addr) { // P2P enabled
int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base))); int *rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(rptr_actual, -num_tokens_sent - 1); st_na_release(rptr_actual, -num_tokens_sent - 1);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx, internode::shmem_ctx_long_atomic_add(ctx,
#else #else
internode::shmem_long_atomic_add( internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} else { } else {
st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); st_na_release(reinterpret_cast<int *>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
} }
...@@ -648,8 +647,9 @@ combine(void* combined_x, ...@@ -648,8 +647,9 @@ combine(void* combined_x,
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base)); char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual); const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_schar_put_nbi_warp(ctx, internode::shmem_ctx_schar_put_nbi_warp(ctx,
#else #else
...@@ -657,9 +657,7 @@ combine(void* combined_x, ...@@ -657,9 +657,7 @@ combine(void* combined_x,
#endif #endif
reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr), reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(buf_ptr),
hidden * sizeof(hip_bfloat16), dst_rank); hidden * sizeof(hip_bfloat16), dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} }
} }
...@@ -677,19 +675,19 @@ combine(void* combined_x, ...@@ -677,19 +675,19 @@ combine(void* combined_x,
#if defined(FORCE_NVSHMEM_API) #if defined(FORCE_NVSHMEM_API)
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank); void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { if (peer_base_addr) {
int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base))); int *req_rptr_actual = (int *)((char *)(peer_base_addr) +
((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(req_rptr_actual, 1); st_na_release(req_rptr_actual, 1);
} else { } else
#endif #endif
{
#if !defined(ROCM_DISABLE_CTX) #if !defined(ROCM_DISABLE_CTX)
internode::shmem_ctx_long_atomic_add(ctx, internode::shmem_ctx_long_atomic_add(ctx,
#else #else
internode::shmem_long_atomic_add( internode::shmem_long_atomic_add(
#endif #endif
rdma_recv_flag + global_expert_idx, 1, dst_rank); rdma_recv_flag + global_expert_idx, 1, dst_rank);
#if defined(FORCE_NVSHMEM_API)
} }
#endif
} else { } else {
st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1); st_na_release(reinterpret_cast<int*>(rdma_recv_flag + global_expert_idx), 1);
} }
...@@ -750,7 +748,8 @@ LOW_LATENCY_COMBINE_RECV: ...@@ -750,7 +748,8 @@ LOW_LATENCY_COMBINE_RECV:
#pragma unroll #pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) { for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources // Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + (reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4); auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type + 4);
// Reduce // Reduce
......
...@@ -162,12 +162,12 @@ class Buffer: ...@@ -162,12 +162,12 @@ class Buffer:
# Get current device and set appropriate HCA # Get current device and set appropriate HCA
current_device = torch.cuda.current_device() current_device = torch.cuda.current_device()
# Translate CUDA_VISIBLE_DEVICES # # Translate CUDA_VISIBLE_DEVICES
if 'CUDA_VISIBLE_DEVICES' in os.environ: # if 'CUDA_VISIBLE_DEVICES' in os.environ:
visible_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",") # visible_devices = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
assert len(visible_devices) > current_device, f"CUDA_VISIBLE_DEVICES has {len(visible_devices)} entries which is fewer than the current device {current_device}" # assert len(visible_devices) > current_device, f"CUDA_VISIBLE_DEVICES has {len(visible_devices)} entries which is fewer than the current device {current_device}"
assert visible_devices[current_device].isdigit(), f"DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices" # assert visible_devices[current_device].isdigit(), f"DEEP_EP_DEVICE_TO_HCA_MAPPING requires CUDA_VISIBLE_DEVICES to contain integer indices"
current_device = int(visible_devices[current_device]) # current_device = int(visible_devices[current_device])
assert current_device in device_mapping, f"Current CUDA device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING" assert current_device in device_mapping, f"Current CUDA device {current_device} not found in DEEP_EP_DEVICE_TO_HCA_MAPPING"
os.environ['NVSHMEM_ENABLE_PE_MAPPING'] = '1' os.environ['NVSHMEM_ENABLE_PE_MAPPING'] = '1'
......
...@@ -7,7 +7,7 @@ from functools import partial ...@@ -7,7 +7,7 @@ from functools import partial
from typing import Literal, Set from typing import Literal, Set
import deep_ep import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, per_token_cast_back_int8 from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back_int8
def test_main(num_tokens: int, def test_main(num_tokens: int,
...@@ -68,6 +68,7 @@ def test_main(num_tokens: int, ...@@ -68,6 +68,7 @@ def test_main(num_tokens: int,
hook() if return_recv_hook else event.current_stream_wait() hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back_int8(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, 1)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0): for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i expert_id = rank * num_local_experts + i
...@@ -133,10 +134,16 @@ def test_main(num_tokens: int, ...@@ -133,10 +134,16 @@ def test_main(num_tokens: int,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True, use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True,
async_finish=False, return_recv_hook=return_recv_hook) async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth # Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 scale_size = 1 # hidden / 128
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4 num_fp8_bytes, num_bf16_bytes = (hidden + scale_size * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens): for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item() num_selections = (topk_idx[i] != -1).sum().item()
...@@ -144,18 +151,20 @@ def test_main(num_tokens: int, ...@@ -144,18 +151,20 @@ def test_main(num_tokens: int,
num_combine_comm_bytes += num_bf16_bytes * num_selections num_combine_comm_bytes += num_bf16_bytes * num_selections
# Separate profiling # Separate profiling
for return_recv_hook in (True, ): for return_recv_hook in (True, False):
group.barrier() group.barrier()
dispatch_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook), dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names='dispatch', kernel_names=('dispatch', 'combine'),
barrier_comm_profiling=True, barrier_comm_profiling=True,
suppress_kineto_output=True, suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1) num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook: if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us', print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
flush=True) flush=True)
else: else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us', print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
flush=True) flush=True)
return hash_value return hash_value
...@@ -178,30 +187,6 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -178,30 +187,6 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
allow_mnnvl=args.allow_mnnvl) allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1) test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
# do_pressure_test = args.pressure_test
# for seed in range(int(1e9) if do_pressure_test else 0):
# if local_rank == 0:
# print(f'Testing with seed {seed} ...', flush=True)
# ref_hash = test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed)
# for _ in range(20):
# assert test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group # Destroy the buffer runtime and communication group
buffer.destroy() buffer.destroy()
dist.barrier() dist.barrier()
...@@ -214,7 +199,7 @@ if __name__ == '__main__': ...@@ -214,7 +199,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Test low-latency EP kernels') parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)') parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=2560, help='Hidden dimension size (default: 7168)') parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 288)') parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication') parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment