Commit d0fcf024 authored by lishen's avatar lishen
Browse files

Merge branch 'quant_master' into 'main'

And quant.

See merge request dcutoolkit/deeplearing/DeepEP!19
parents 81e56124 ace6e18e
......@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts) {
const int num_scales = hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
......
......@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
......@@ -1327,8 +1327,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);
// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16));
auto packed_recv_x_dtype = torch::kBFloat16;
switch (quant_type) {
case 1: packed_recv_x_dtype = torch::kInt8; break;
case 2: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 3: packed_recv_x_dtype = torch::kFloat8_e4m3fn; break;
case 4: packed_recv_x_dtype = torch::kFloat8_e5m2; break;
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
......@@ -1336,21 +1343,28 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate column-majored scales
auto packed_recv_x_scales = std::optional<torch::Tensor>();
void* packed_recv_x_scales_ptr = nullptr;
if (use_fp8) {
if (quant_type > 0) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0);
EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
EP_HOST_ASSERT(hidden % (QUANTIZATION_GROUPSIZE * 4) == 0);
// 计算scale_col的大小
int scales_col_size = 1; // 默认为per-channel
if (quant_group_size > 0) {
if (quant_type == 3) { // FP8_UE8M0比较特殊
scales_col_size = hidden / (QUANTIZATION_GROUPSIZE * 4);
} else {
scales_col_size = hidden / QUANTIZATION_GROUPSIZE;
}
}
if (use_ue8m0) {
EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank},
// 设置packed_recv_x_scales
if (quant_type == 3) { // FP8_UE8M0比较特殊,需要单独处理
EP_HOST_ASSERT(fp8_round_scale && quant_group_size == 128);
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank},
packed_recv_x_scales = torch::empty({num_local_experts, scales_col_size, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
......@@ -1370,7 +1384,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_int8,
quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
......
......@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int quant_group_size, bool fp8_round_scale,
bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
......@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x,
......
......@@ -23,7 +23,7 @@
#define LOW_LATENCY_SEND_PHASE 1
#define LOW_LATENCY_RECV_PHASE 2
#define NUM_INTERNODE_DISPATCH_BLOCKS_PER_CHANNEL 3
#define FP8_QUANTIZATION_NUM_PER_CHANNEL 128
#define QUANTIZATION_GROUPSIZE 128
#define DEFAULT_NUM_CU 20
#define DEFAULT_NUM_MAX_XGMI_CHUNKED_SEND_TOKENS 6
......
This diff is collapsed.
......@@ -62,7 +62,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case_macro(8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
while (false)
......@@ -83,7 +83,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 20: \
case_macro(20); \
default: \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} \
while (false)
......@@ -96,7 +96,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case 8: \
case_macro(dtype, 8); \
default: \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
while (false)
......@@ -107,7 +107,7 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case HIP_R_32F: \
case_macro(float); \
default: \
EP_HOST_ASSERT(false and "Unsupported type"); \
EP_HOST_ASSERT(false and "Unsupported type"); \
} \
while (false)
......@@ -121,7 +121,9 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, Args &&...a
case_macro(4096); \
case 7168: \
case_macro(7168); \
case 8192: \
case_macro(8192); \
default: \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
EP_HOST_ASSERT(false and "Unsupported hidden"); \
} \
while (false)
......@@ -341,10 +341,13 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values);
}
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f;
// 设置不同的量化方式的最大值与相反数
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
constexpr float kInt8Amax = 127.0f;
constexpr float kFinfoAmaxE5M2 = 57344.0f;
constexpr float kFinfoAmaxInvE5M2 = 1.0f / kFinfoAmaxE5M2;
constexpr float kFinfoAmaxInt8 = 127.0f;
constexpr float kFinfoAmaxInvInt8 = 1.0f / 127.0f;
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
......@@ -359,22 +362,33 @@ __forceinline__ __device__ int fast_log2_ceil(float x) {
return exp_x - 127 + (man_bits != 0);
}
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) {
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
template <int kQuantType>
__forceinline__ __device__ void calculate_quant8bit_scales(float amax, float& scale, float& scale_inv, bool round_scale=0) {
amax = fmaxf(amax, 1e-6f);
if constexpr(kQuantType == 1) { // 使用 INT8 对称量化
scale_inv = kFinfoAmaxInvInt8 * amax;
scale = kFinfoAmaxInt8 / amax;
} else if constexpr(kQuantType == 2 || kQuantType == 3) { // 使用 FP8_E4M3 或 FP8_UE8M0 非对称量化
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE4M3;
scale = kFinfoAmaxE4M3 / amax;
}
} else if constexpr(kQuantType == 4) { // 使用 FP8_E5M2 对称量化
if (round_scale) {
auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE5M2);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * kFinfoAmaxInvE5M2;
scale = kFinfoAmaxE5M2 / amax;
}
}
}
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) {
scale = kInt8Amax / amax;
scale_inv = amax / kInt8Amax;
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) {
......
......@@ -841,7 +841,7 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False,
quant_type: int = 1, quant_group_size: int = 0, fp8_round_scale: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
......@@ -858,10 +858,20 @@ class Buffer:
only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported.
num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
num_experts: the number of all experts.
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
量化配置
quant_type: int 量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3_FNUZ)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2_FNUZ)
quant_group_size: int 量化分组大小
0 -> 逐token量化 (per-channel)
128-> 每 128 元素一组 (per-group) 量化
fp8_round_scale: bool 是否将 FP8 缩放因子取整为 2 的幂
true -> 缩放因子 = 2^k,硬件零开销
false -> 缩放因子 = 任意浮点,精度更高
异步配置
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
......@@ -869,15 +879,25 @@ class Buffer:
Returns:
recv_x: a tensor or tuple with received tokens for each expert.
With `use_fp8=True`: the first element is a `torch.Tensor` shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`.
The second tensor is the corresponding scales for the first element with shape
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`,
if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`.
Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility.
With `use_fp8=False`, the result would be a tensor shaped as
`[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`.
- packed_recv_x:
存储接收到的 Token 数据,形状为
`[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`。
数据类型取决于 quant_type:
quant_type == 1 -> torch.int8
quant_type == 2 -> torch.float8_e4m3fnuz
quant_type == 3 -> torch.float8_e4m3fnuz (UE8M0 使用 E4M3 格式存储)
quant_type == 4 -> torch.float8_e5m2fnuz
其他 (非量化) -> torch.bfloat16
- packed_recv_x_scales (可选):
仅在 quant_type > 0 时存在,存储量化的 Scale 值。
形状为 `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, scales_col_size]`。
- 当 quant_type == 3 (UE8M0) 时:
scales_col_size = hidden // 512
数据类型为 torch.int (内部打包存储 4-bit scale)。
*注意:此模式强制要求 fp8_round_scale=True 且 group_size=128。
- 当 quant_type == 1, 2, 4 时:
scales_col_size = hidden // 128 (若使用 group_size) 或 1 (per-channel)。
数据类型为 torch.float32。
Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
......@@ -889,14 +909,15 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_int8,
quant_type, quant_group_size, fp8_round_scale,
async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
packed_recv_src_info, packed_recv_layout_range)
return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
EventOverlap(event, tensors_to_record if async_finish else None), hook
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
......
......@@ -6,7 +6,7 @@ from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_pg_back, per_token_cast_pc_back
def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]):
......@@ -58,7 +58,8 @@ def test_main(num_tokens: int,
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
......@@ -80,22 +81,35 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, True):
for dispatch_use_fp8 in (False, True):
for round_scale in (False, True) if dispatch_use_fp8 else (False,):
for use_ue8m0 in (False, True) if round_scale else (False,):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
num_times += 1
for _ in range((num_times % 2) + 1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0,
quant_type=quant_type, fp8_round_scale=fp8_round_scale, quant_group_size=quant_group_size,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
if dispatch_use_fp8 else packed_recv_x.clone()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_quant else packed_recv_x
if not dispatch_use_quant:
simulated_gemm_x = packed_recv_x.clone()
elif quant_group_size == 0:
simulated_gemm_x = per_token_cast_pc_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].reshape(-1)).view(packed_recv_x[0].shape)
elif quant_group_size == 128:
simulated_gemm_x = per_token_cast_pg_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
if not dispatch_use_quant:
recv_x = packed_recv_x[i]
elif quant_group_size == 0:
recv_x = per_token_cast_pc_back(packed_recv_x[0][i], packed_recv_x[1][i])
elif quant_group_size == 128:
recv_x = per_token_cast_pg_back(packed_recv_x[0][i], packed_recv_x[1][i])
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
......@@ -113,18 +127,25 @@ def test_main(num_tokens: int,
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if round_scale:
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8:
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
if quant_group_size != 0:
if fp8_round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
for j in range(num_ranks):
begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item()
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
......@@ -147,8 +168,8 @@ def test_main(num_tokens: int,
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
assert torch.isnan(combined_x).sum().item() == 0
# if not round_scale:
assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: diff={diff}, dispatch_use_fp8={dispatch_use_fp8}, zero_copy={zero_copy}'
# if not fp8_round_scale:
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x)
# noinspection PyShadowingNames
......@@ -162,7 +183,8 @@ def test_main(num_tokens: int,
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
quant_type=2, quant_group_size=0,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
......
import argparse
import random
import os
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back_int8
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
num_topk: int,
rank: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for _ in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
# For failure simulation and shrink testing
mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, ):
for dispatch_use_fp8 in (True, ):
for round_scale in (False, ):
for use_ue8m0 in (False, ):
num_times += 1
use_int8 = True
for _ in range(1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
simulated_gemm_x = per_token_cast_back_int8(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, 1)).view(packed_recv_x[0].shape)
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back_int8(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (
recv_layout_range
& int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
if num_valid_tokens == 0:
continue
# Check received data
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
elif use_int8:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
print("dispatch int 8 pass")
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
scale_size = 1 # hidden / 128
num_fp8_bytes, num_bf16_bytes = (hidden + scale_size * 4 + 16), hidden * 2
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Separate profiling
for return_recv_hook in (True, False):
group.barrier()
dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names=('dispatch', 'combine'),
barrier_comm_profiling=True,
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us',
flush=True)
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us',
flush=True)
return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
......@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
if x_fp8.numel() == 0:
return x_fp8.to(torch.bfloat16)
def per_token_cast_pg_back(x: torch.Tensor, x_scales: torch.Tensor):
if x.numel() == 0:
return x.to(torch.bfloat16)
assert x_fp8.dim() == 2
m, n = x_fp8.shape
assert x.dim() == 2
m, n = x.shape
aligned_n = align_up(n, 128)
x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0)
x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0)
if x_scales.dtype == torch.int:
x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23
x_scales = x_scales.view(dtype=torch.float)
x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
x_fp32_padded = x_padded.to(torch.float32).view(x.size(0), -1, 128)
x_scales = x_scales.view(x.size(0), -1, 1)
return (x_fp32_padded * x_scales).view(x_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
def per_token_cast_pc_back(x_int8: torch.Tensor, x_scales: torch.Tensor):
if x_int8.numel() == 0:
return x_int8.to(torch.bfloat16)
......@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
m, n = x_int8.shape
aligned_n = align_up(n, 128)
x_int8_padded = torch.nn.functional.pad(
x_int8, (0, aligned_n - n), mode='constant', value=0
)
x_int8_padded = torch.nn.functional.pad(x_int8, (0, aligned_n - n), mode='constant', value=0)
x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1)
x_scales = x_scales.view(m, -1, 1).to(torch.float32)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq = (x_fp32_padded * x_scales).view(m, aligned_n)
return x_deq[:, :n].to(torch.bfloat16).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment