Unverified commit c5facf5c, authored by Zhean Xu, committed by GitHub
Browse files

Support 10-bit LogFMT Combine (#345)



* independent logfmt_simulate function

* draft: logfmt low latency combine

* Minor bug fixes

* Fix non-logfmt bugs

* Fix logfmt bugs

* Fix logfmt bugs

* Minor fix

* Minor fix

* Clean code

* Clean code

* Use fewer regs

* Use two warp groups

* Correct shared memory size

* Minor fix

* Minor fix

* More rigorous tests

* Clean code

* Use more SMs

* Use different unroll factor for send & recv

* Update csrc/kernels/internode_ll.cu
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update csrc/kernels/internode_ll.cu
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Some renaming

* Some comments of tests

* Format `logfmt_encode`

* More lints

* Some refactors on sends

* Fix testing

* Fix bugs

* Renaming

* Use the full warp

* Unify combine recv

* Lint

* Lint

* Support 2560

* Fix meta buffer dtype

* Better encode calls

* Better amin/max writes

* Extra sync

* Read `topk_idx` by once

* Better specialization

* Read weights by once

* Rename

* Bug fixed

* Some renaming

* Fix local memory usage for sending

* Fix local memory usage for receiving

* Less writes

* Optimize performance

* Optimize performance

* Better performance

* Optimize performance

* Fix rounding

* Manually unroll

* Fix bench

---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
parent 26cf250a
...@@ -136,9 +136,10 @@ struct LowLatencyLayout { ...@@ -136,9 +136,10 @@ struct LowLatencyLayout {
// Message sizes // Message sizes
// NOTES: you should add a control `int4` for combine messages if you want to do data transformation // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
// NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16); size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
// Send buffer // Send buffer
size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
......
This diff is collapsed.
...@@ -322,8 +322,15 @@ __device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arriv ...@@ -322,8 +322,15 @@ __device__ __forceinline__ void mbarrier_init(uint64_t* mbar_ptr, uint32_t arriv
asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr)); asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" :: "r"(arrive_count), "r"(mbar_int_ptr));
} }
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase) { __device__ __forceinline__ void mbarrier_inval(uint64_t* mbar_ptr) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr)); auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
asm volatile("mbarrier.inval.shared::cta.b64 [%0];" :: "r"(mbar_int_ptr));
}
template <bool kWithMultiStages = false>
__device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phase, int stage_idx = 0) {
auto mbar_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
const auto& wait = kWithMultiStages ? (phase >> stage_idx) & 1 : phase;
asm volatile("{\n\t" asm volatile("{\n\t"
".reg .pred P1; \n\t" ".reg .pred P1; \n\t"
"LAB_WAIT: \n\t" "LAB_WAIT: \n\t"
...@@ -331,8 +338,8 @@ __device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phas ...@@ -331,8 +338,8 @@ __device__ __forceinline__ void mbarrier_wait(uint64_t* mbar_ptr, uint32_t& phas
"@P1 bra DONE; \n\t" "@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t" "bra LAB_WAIT; \n\t"
"DONE: \n\t" "DONE: \n\t"
"}" :: "r"(mbar_int_ptr), "r"(phase), "r"(0x989680)); "}" :: "r"(mbar_int_ptr), "r"(wait), "r"(0x989680));
phase ^= 1; phase ^= kWithMultiStages ? (1 << stage_idx) : 1;
} }
__device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) { __device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr, int num_bytes) {
...@@ -340,6 +347,11 @@ __device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr ...@@ -340,6 +347,11 @@ __device__ __forceinline__ void mbarrier_arrive_and_expect_tx(uint64_t* mbar_ptr
asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr)); asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: "r"(num_bytes), "r"(mbar_int_ptr));
} }
// Issue a plain `mbarrier.arrive` (no expected-transaction bytes) on the barrier at `mbar_ptr`.
// `mbar_ptr` is converted with `__cvta_generic_to_shared` and truncated to 32 bits, matching the
// `shared::cta` addressing used by the instruction; the `_` operand discards the instruction's result.
__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar_ptr) {
    const uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar_ptr));
    asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" :: "r"(smem_addr));
}
__device__ __forceinline__ void tma_store_fence() { __device__ __forceinline__ void tma_store_fence() {
asm volatile ("fence.proxy.async.shared::cta;"); asm volatile ("fence.proxy.async.shared::cta;");
} }
...@@ -518,36 +530,56 @@ __forceinline__ __device__ void release_lock(int* mutex) { ...@@ -518,36 +530,56 @@ __forceinline__ __device__ void release_lock(int* mutex) {
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
// Binary functor returning `a & b`; passed as the reduction operator by `warp_reduce_and`
template <typename T>
struct ReduceAnd {
    __device__ T operator()(T a, T b) const {
        return a & b;
    }
};
// Binary functor returning `a | b`; passed as the reduction operator by `warp_reduce_or`
template <typename T>
struct ReduceOr {
    __device__ T operator()(T a, T b) const {
        return a | b;
    }
};
// Unified reduction function // Unified reduction function
template <uint32_t kNumLanes, typename T, typename Op> template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) { __forceinline__ __device__ T warp_reduce(T value, Op op) {
EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or EP_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
kNumLanes == 4 or kNumLanes == 2 or kNumLanes == 1, kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes"); "Invalid number of lanes");
constexpr uint32_t mask = 0xffffffff;
if constexpr (kNumLanes >= 32) value = op(value, __shfl_xor_sync(0xffffffff, value, 16)); if constexpr (kIntergroupReduce) {
if constexpr (kNumLanes >= 16) value = op(value, __shfl_xor_sync(0xffffffff, value, 8)); if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
if constexpr (kNumLanes >= 8) value = op(value, __shfl_xor_sync(0xffffffff, value, 4)); if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanes >= 4) value = op(value, __shfl_xor_sync(0xffffffff, value, 2)); if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanes >= 2) value = op(value, __shfl_xor_sync(0xffffffff, value, 1)); if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
} else {
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
}
return value; return value;
} }
// Convenience aliases // Convenience aliases
template < uint32_t kNumLanes = 32, typename T> template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) { __forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceSum<T>{}); return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
} }
template <uint32_t kNumLanes = 32, typename T> template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_max(T value) { __forceinline__ __device__ T warp_reduce_max(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceMax<T>{}); return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMax<T>{});
} }
template <uint32_t kNumLanes = 32, typename T> template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_min(T value) { __forceinline__ __device__ T warp_reduce_min(T value) {
return warp_reduce<kNumLanes, T>(value, ReduceMin<T>{}); return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceMin<T>{});
}
// Convenience wrapper: bitwise-AND reduction of `value` via `warp_reduce`.
// Both template knobs are forwarded unchanged; the defaults select `kNumLanesPerGroup = 32`
// with `kIntergroupReduce = false`.
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_and(T value) {
    const ReduceAnd<T> bitwise_and{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, bitwise_and);
}
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_or(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
} }
} // namespace deep_ep } // namespace deep_ep
...@@ -27,7 +27,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -27,7 +27,14 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1 x_list = [x]
for i in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case are lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
...@@ -39,7 +46,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -39,7 +46,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Check dispatch correctness # Check dispatch correctness
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in (x, x_pure_rand): for current_x in x_list:
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True): for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False, ): for round_scale in (False, True) if dispatch_use_fp8 else (False, ):
...@@ -71,7 +78,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -71,7 +78,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
# Check received data # Check received data
if current_x is not x_pure_rand: if current_x is x:
recv_x = recv_x[:num_valid_tokens] recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1) recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens] recv_src_info = recv_src_info[:num_valid_tokens]
...@@ -104,7 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -104,7 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
if do_check: if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0 assert torch.isnan(combined_x).sum().item() == 0
assert diff < (7e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {zero_copy=}' assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames # noinspection PyShadowingNames
...@@ -117,7 +124,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -117,7 +124,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# noinspection PyShadowingNames # noinspection PyShadowingNames
def test_func(return_recv_hook: bool): def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \ recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(x_pure_rand, topk_idx, num_tokens, num_experts, buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
...@@ -127,11 +134,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, ...@@ -127,11 +134,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
# Calculate bandwidth # Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens): for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item() num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections
# Dispatch + combine testing # Dispatch + combine testing
avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False)) avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False))
......
...@@ -53,7 +53,7 @@ def per_token_cast_to_fp8(x: torch.Tensor): ...@@ -53,7 +53,7 @@ def per_token_cast_to_fp8(x: torch.Tensor):
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_scales.dtype == torch.int: if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.int8).to(torch.int) << 23 x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float) x_scales = x_scales.view(dtype=torch.float)
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1) x_scales = x_scales.view(x_fp8.size(0), -1, 1)
...@@ -171,6 +171,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr ...@@ -171,6 +171,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests): for _ in range(num_tests):
fn() fn()
torch.cuda.synchronize()
prof.step() prof.step()
# Parse the profiling table # Parse the profiling table
...@@ -219,4 +220,4 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr ...@@ -219,4 +220,4 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr
def hash_tensor(t: torch.Tensor): def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item() return t.view(torch.int).sum().item()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment