Unverified commit c8dceba1, authored by Chenggang Zhao and committed by GitHub
Browse files

Use TMA instead of LD/ST for intra-node normal kernels (#191)

* Update CMake files

* Use TMA instead of LD/ST for intranode dispatch

* Use TMA instead of LD/ST for intranode combine

* Adjust configs

* Test default configs as well

* More warps for combine

* Add inter-thread fence

* Enable more warps

* Do not use TMA for senders

* Update configs

* Remove useless wait
parent df4debe3
...@@ -9,7 +9,10 @@ set(CUDA_SEPARABLE_COMPILATION ON) ...@@ -9,7 +9,10 @@ set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-O3") list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
set(TORCH_CUDA_ARCH_LIST "9.0") set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
find_package(CUDAToolkit REQUIRED) find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED) find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
...@@ -19,9 +22,8 @@ add_library(nvshmem ALIAS nvshmem::nvshmem) ...@@ -19,9 +22,8 @@ add_library(nvshmem ALIAS nvshmem::nvshmem)
add_library(nvshmem_host ALIAS nvshmem::nvshmem_host) add_library(nvshmem_host ALIAS nvshmem::nvshmem_host)
add_library(nvshmem_device ALIAS nvshmem::nvshmem_device) add_library(nvshmem_device ALIAS nvshmem::nvshmem_device)
# There seem to be bugs with CMake, NVCC 12 and C++17
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 14) set(CMAKE_CUDA_STANDARD 17)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR}) include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR}) link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR})
......
...@@ -19,8 +19,6 @@ ...@@ -19,8 +19,6 @@
#ifdef __CLION_IDE__ #ifdef __CLION_IDE__
#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier)
#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier)
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { asm volatile("trap;"); }
#define printf host_device_printf
#endif #endif
// Remove Torch restrictions // Remove Torch restrictions
......
This diff is collapsed.
...@@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value ...@@ -266,6 +266,67 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, const int4& value
::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w)); ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
} }
// Issues the PTX instruction `fence.proxy.async.shared::cta`.
// Per the PTX ISA, this orders the calling thread's shared-memory accesses made through the generic
// proxy against accesses made through the async proxy (the path used by `cp.async.bulk`).
// The "memory" clobber keeps the compiler from moving ordinary loads/stores across the fence while
// generating PTX; `volatile` by itself only prevents the asm statement from being deleted or moved.
__device__ __forceinline__ void fence_view_async_shared() {
    asm volatile("fence.proxy.async.shared::cta; \n" ::: "memory");
}
// Issues the PTX instruction `fence.mbarrier_init.release.cluster`.
// NOTE(review): per the PTX ISA this fence is intended to publish preceding `mbarrier.init`
// operations before other threads use the barriers; confirm that call sites invoke it after
// `mbarrier_init` and before the first wait/arrive.
__device__ __forceinline__ void fence_barrier_init() {
asm volatile("fence.mbarrier_init.release.cluster; \n" :: );
}
// Initializes the 64-bit mbarrier object that `mbar_ptr` points to.
//   mbar_ptr:     generic pointer to the barrier; it is converted to a shared-space address, so it
//                 is expected to refer to shared memory.
//   arrive_count: value passed as the count operand of `mbarrier.init`.
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arrive_count) {
    // The PTX instruction takes a 32-bit shared-space address rather than a generic pointer
    const uint32_t mbar_smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
    asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;"
                 :: "r"(arrive_count), "r"(mbar_smem_addr));
}
// Spins until the mbarrier at `mbar_ptr` reports completion of the phase with parity `phase`,
// then flips `phase` so that the next call waits for the following phase.
//   mbar_ptr: generic pointer to the barrier; converted to a shared-space address.
//   phase:    in/out parity bit tracked by the caller (toggled with `^= 1` on return).
// The "memory" clobber prevents the compiler from hoisting loads of data guarded by the barrier
// (e.g. shared memory filled by `tma_load_1d`) above the wait loop.
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) {
    // Third operand of `mbarrier.try_wait` (0x989680 == 10,000,000); the PTX ISA describes it as a
    // suspend-time hint, so each `try_wait` may block up to that long before the loop retries.
    constexpr uint32_t kSuspendTimeHint = 0x989680;
    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
    // The braces open a PTX scope so the predicate register and labels stay local to this expansion
    asm volatile("{\n\t"
                 ".reg .pred P1; \n\t"
                 "LAB_WAIT: \n\t"
                 "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
                 "@P1 bra DONE; \n\t"
                 "bra LAB_WAIT; \n\t"
                 "DONE: \n\t"
                 "}" :: "r"(mbar_int_ptr), "r"(phase), "r"(kSuspendTimeHint) : "memory");
    phase ^= 1;
}
// Performs `mbarrier.arrive.expect_tx` on the barrier at `mbar_ptr`: the calling thread arrives and
// registers `num_bytes` as the transaction byte count operand. The returned state token is
// discarded (`_` sink operand).
//   mbar_ptr:  generic pointer to the barrier; converted to a shared-space address.
//   num_bytes: byte count operand; presumably the total size of the bulk copies that will complete
//              on this barrier (verify against callers).
// The "memory" clobber keeps the compiler from moving ordinary memory accesses across the arrive,
// matching the TMA copy helpers in this file.
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
    auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
    asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t"
                 :: "r"(num_bytes), "r"(mbar_int_ptr) : "memory");
}
// Issues the PTX instruction `fence.proxy.async.shared::cta` (the same instruction as
// `fence_view_async_shared`).
// NOTE(review): presumably called after writing shared memory with regular stores and before
// `tma_store_1d`, so the bulk store observes those writes; confirm at call sites.
// The "memory" clobber stops the compiler from sinking the preceding shared-memory stores below
// the fence while generating PTX.
__device__ __forceinline__ void tma_store_fence() {
    asm volatile ("fence.proxy.async.shared::cta;" ::: "memory");
}
// 64-bit values passed as the `.L2::cache_hint` operand of the bulk-copy instructions in
// `tma_load_1d` / `tma_store_1d`, selected there by the `evict_first` flag.
// NOTE(review): these look like pre-encoded L2 cache-policy descriptors (evict-first vs. normal
// eviction priority); the bit layout is opaque here, so confirm against the vendor reference.
constexpr uint64_t kEvictFirst = 0x12f0000000000000;
constexpr uint64_t kEvictNormal = 0x1000000000000000;
// Starts an asynchronous bulk copy of `num_bytes` bytes from global to shared memory
// (`cp.async.bulk`); completion is reported to the mbarrier as transaction bytes.
//   smem_ptr:    destination; converted to a shared-space address (written despite `const`).
//   gmem_ptr:    source address in global memory.
//   mbar_ptr:    barrier that receives the `complete_tx::bytes` notification.
//   num_bytes:   size operand of the copy.
//   evict_first: chooses `kEvictFirst` (default) or `kEvictNormal` as the L2 cache hint.
// The call only issues the copy; use `mbarrier_wait` on `mbar_ptr` before reading the data.
__device__ __forceinline__ void tma_load_1d(const void* smem_ptr, const void* gmem_ptr, uint64_t* mbar_ptr, int num_bytes,
                                            bool evict_first = true) {
    // Pick the L2 cache-hint operand
    uint64_t l2_policy = kEvictNormal;
    if (evict_first)
        l2_policy = kEvictFirst;

    // The instruction addresses shared memory through 32-bit shared-space addresses
    const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
    const uint32_t dst_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));

    asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n"
                 :: "r"(dst_addr), "l"(gmem_ptr), "r"(num_bytes), "r"(barrier_addr), "l"(l2_policy)
                 : "memory");
}
// Starts an asynchronous bulk copy of `num_bytes` bytes from shared to global memory
// (`cp.async.bulk`) and immediately commits it as its own bulk group.
//   smem_ptr:    source; converted to a shared-space address.
//   gmem_ptr:    destination address in global memory (written despite `const`).
//   num_bytes:   size operand of the copy.
//   evict_first: chooses `kEvictFirst` (default) or `kEvictNormal` as the L2 cache hint.
// The call does not wait for the copy; pair it with `tma_store_wait`.
__device__ __forceinline__ void tma_store_1d(const void* smem_ptr, const void* gmem_ptr, int num_bytes,
                                             bool evict_first = true) {
    // Pick the L2 cache-hint operand
    uint64_t l2_policy = kEvictNormal;
    if (evict_first)
        l2_policy = kEvictFirst;

    // The instruction addresses shared memory through a 32-bit shared-space address
    const uint32_t src_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));

    asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n"
                 :: "l"(gmem_ptr), "r"(src_addr), "r"(num_bytes), "l"(l2_policy)
                 : "memory");
    // Close the bulk group so that `cp.async.bulk.wait_group` can track this copy
    asm volatile("cp.async.bulk.commit_group;");
}
// Issues `cp.async.bulk.wait_group.read N`, with `N` supplied as a compile-time immediate.
// NOTE(review): per the PTX ISA this waits until at most `N` of the most recently committed bulk
// groups are still pending, and the `.read` form only waits for the source reads to finish, not
// for the global writes to become visible. With the default `N = 0` that means all committed
// groups; confirm callers only rely on the shared-memory source being reusable afterwards.
template <int N = 0>
__device__ __forceinline__ void tma_store_wait() {
asm volatile("cp.async.bulk.wait_group.read %0;" :: "n"(N) : "memory");
}
template <typename dtype_t> template <typename dtype_t>
__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) { __host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b; return (a + b - 1) / b;
......
...@@ -171,8 +171,8 @@ class Buffer: ...@@ -171,8 +171,8 @@ class Buffer:
""" """
config_map = { config_map = {
2: Config(Buffer.num_sms, 16, 256, 6, 128), 2: Config(Buffer.num_sms, 24, 256, 6, 128),
4: Config(Buffer.num_sms, 16, 256, 6, 128), 4: Config(Buffer.num_sms, 6, 256, 6, 128),
8: Config(Buffer.num_sms, 6, 256, 6, 128), 8: Config(Buffer.num_sms, 6, 256, 6, 128),
16: Config(Buffer.num_sms, 16, 288, 20, 128), 16: Config(Buffer.num_sms, 16, 288, 20, 128),
24: Config(Buffer.num_sms, 8, 288, 32, 128), 24: Config(Buffer.num_sms, 8, 288, 32, 128),
...@@ -198,9 +198,9 @@ class Buffer: ...@@ -198,9 +198,9 @@ class Buffer:
""" """
config_map = { config_map = {
2: Config(Buffer.num_sms, 6, 256, 6, 128), 2: Config(Buffer.num_sms, 10, 256, 6, 128),
4: Config(Buffer.num_sms, 6, 256, 6, 128), 4: Config(Buffer.num_sms, 9, 256, 6, 128),
8: Config(Buffer.num_sms, 6, 256, 6, 128), 8: Config(Buffer.num_sms, 4, 256, 6, 128),
16: Config(Buffer.num_sms, 2, 288, 28, 128), 16: Config(Buffer.num_sms, 2, 288, 28, 128),
24: Config(Buffer.num_sms, 1, 288, 20, 128), 24: Config(Buffer.num_sms, 1, 288, 20, 128),
32: Config(Buffer.num_sms, 1, 288, 20, 128), 32: Config(Buffer.num_sms, 1, 288, 20, 128),
......
...@@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: ...@@ -153,14 +153,20 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
for current_x in (x_e4m3, x): for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None best_time, best_results = 1e10, None
nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes
for nvl_chunk_size in range(4, 33, 4): for nvl_chunk_size in tuple(range(4, 33, 2)) + (0, ):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) if nvl_chunk_size > 0:
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
else:
# Test default config as well
deep_ep.Buffer.set_num_sms(num_sms)
config = deep_ep.Buffer.get_dispatch_config(num_ranks)
tune_args = {'x': current_x, 'handle': handle, 'config': config} tune_args = {'x': current_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.dispatch(**tune_args))[0] t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time: if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size) best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0: if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if local_rank == 0: if local_rank == 0:
print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True) print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
print('', flush=True) print('', flush=True)
...@@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: ...@@ -180,13 +186,19 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
# Tune combine performance # Tune combine performance
best_time, best_results = 1e10, None best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 7, 1): for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ):
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) if nvl_chunk_size > 0:
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
else:
# Test default config as well
deep_ep.Buffer.set_num_sms(num_sms)
config = deep_ep.Buffer.get_combine_config(num_ranks)
tune_args = {'x': recv_x, 'handle': handle, 'config': config} tune_args = {'x': recv_x, 'handle': handle, 'config': config}
t = bench(lambda: buffer.combine(**tune_args))[0] t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0: if local_rank == 0:
print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True) print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
if t < best_time: f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
if t < best_time and nvl_chunk_size > 0:
best_time, best_results = t, (num_sms, nvl_chunk_size) best_time, best_results = t, (num_sms, nvl_chunk_size)
if local_rank == 0: if local_rank == 0:
...@@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -202,7 +214,7 @@ def test_loop(local_rank: int, num_local_ranks: int):
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, buffer = deep_ep.Buffer(group, int(2e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
torch.manual_seed(rank) torch.manual_seed(rank)
...@@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -216,6 +228,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the communication group
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__': if __name__ == '__main__':
num_processes = 8 num_processes = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment