Commit 75b00cfb authored by lijian6's avatar lijian6
Browse files

Merge branch 'logfmt_master' into 'main'

低延迟combine支持10bit量化代码

See merge request dcutoolkit/deeplearing/DeepEP!21
parents dbf9fd61 17d9c844
...@@ -136,7 +136,7 @@ struct LowLatencyLayout { ...@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) { int num_ranks, int num_experts, int quant_group_size=0) {
const int num_scales = quant_group_size == 0 ? 4 : hidden / QUANTIZATION_GROUPSIZE; // 应该是1,但是代码中为了满足int4对齐 const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout: // Dispatch and combine layout:
// - 2 symmetric odd/even send buffer // - 2 symmetric odd/even send buffer
...@@ -148,11 +148,11 @@ struct LowLatencyLayout { ...@@ -148,11 +148,11 @@ struct LowLatencyLayout {
// transformation // transformation
EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden)); EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
size_t num_bytes_per_dispatch_msg = size_t num_bytes_per_dispatch_msg =
sizeof(int4) + sizeof(int4) + std::max(hidden * sizeof(hip_bfloat16), hidden +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float)); (quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等 // 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16); size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162);
// Send buffer // Send buffer
size_t dispatch_send_buffer_bytes = size_t dispatch_send_buffer_bytes =
......
...@@ -1413,6 +1413,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1413,6 +1413,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) { const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
...@@ -1482,6 +1483,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1482,6 +1483,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
use_logfmt,
workspace, num_device_sms, launch_stream, workspace, num_device_sms, launch_stream,
phases, zero_copy); phases, zero_copy);
}; };
......
...@@ -185,6 +185,7 @@ public: ...@@ -185,6 +185,7 @@ public:
const torch::Tensor& src_info, const torch::Tensor& layout_range, const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats, const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt); const std::optional<torch::Tensor>& out = std::nullopt);
......
...@@ -159,6 +159,7 @@ void combine(void* combined_x, ...@@ -159,6 +159,7 @@ void combine(void* combined_x,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy); int phases, bool zero_copy);
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh" #include "shmem_wrapper.cuh"
#include "internode_ll_logfmt.cuh"
namespace deep_ep { namespace deep_ep {
...@@ -612,7 +613,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -612,7 +613,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#undef DISPATCH_LAUNCH_CASE #undef DISPATCH_LAUNCH_CASE
} }
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16> template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void __global__ __launch_bounds__(16 * kWarpSize, 1) void
combine(void* combined_x, combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
...@@ -643,7 +644,24 @@ combine(void* combined_x, ...@@ -643,7 +644,24 @@ combine(void* combined_x,
// Message package // Message package
EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden"); EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
/////////////// LogFMT使用 ///////////////
constexpr int bSupportLogFMT = kUseLogFMT && hidden_bf16_int4 % (kWarpSize * 2) == 0;
constexpr int kNumSendUnrolls = bSupportLogFMT ? 2 : 1;
constexpr int kNumRecvUnrolls = bSupportLogFMT ? 2 : 1;
constexpr int kNumMsgInt4ElemPerWarp = kWarpSize * kNumSendUnrolls; // 每个warp发送的int4元素数据量,即每个warp发送 kNumMsgInt4ElemPerWarp*sizeof(int4)/sizeof(bfloat16)
EP_STATIC_ASSERT(hidden_bf16_int4 % (kNumSendUnrolls * kWarpSize) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors");
constexpr int kNumDivisions = kHidden / QUANTIZATION_GROUPSIZE;
constexpr int kNumMetaBytes = kNumDivisions * sizeof(__hip_bfloat162); // 用于记录数据的最大最小值
constexpr int kNumSendLogFMTBytes = kNumMsgInt4ElemPerWarp * sizeof(int4);
constexpr int kNumStages = 1; // 使用kNumStages>1,则需要的LDS大于64KB
constexpr int kLogFMTShmemSize = kMaxNumWarps * (kNumStages * kNumSendLogFMTBytes + kNumMetaBytes);
__shared__ uint8_t smem_buffer[kLogFMTShmemSize];
/////////////////////////////////////////////
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16) + kNumMetaBytes;
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 初始化用于细粒度warp间同步的计数器数组 // 初始化用于细粒度warp间同步的计数器数组
...@@ -683,6 +701,12 @@ combine(void* combined_x, ...@@ -683,6 +701,12 @@ combine(void* combined_x,
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) + const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// 用于logfmt的LDS
auto smem_ptr = smem_buffer + warp_id * (kNumStages * kNumSendLogFMTBytes + kNumMetaBytes);
// 存储logfmt的起始地址,并根据stage_idx进行索引块
auto logfmt_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * kNumSendLogFMTBytes); });
// 存储logfmt的最大最小值
auto meta_buffers = bSupportLogFMT ? reinterpret_cast<__hip_bfloat162*>(smem_ptr + kNumStages * kNumSendLogFMTBytes) : nullptr;
// Unpack layout // Unpack layout
int offset, num_tokens_to_send; int offset, num_tokens_to_send;
...@@ -699,20 +723,78 @@ combine(void* combined_x, ...@@ -699,20 +723,78 @@ combine(void* combined_x,
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row); const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank); // 采用logfmt或者直接拷贝
if (p2p_ptr == 0) { // RDMA uint64_t dst_p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr); int num_send_bytes = hidden * sizeof(hip_bfloat16);
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); if (not zero_copy or dst_p2p_ptr != 0) {
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr): reinterpret_cast<int4*>(dst_p2p_ptr);
// 设置数据的真实偏移量
int logfmt_offset_bytes = kNumMetaBytes;
// 进入循环,逐步拷贝数据
constexpr int encode_num_warps = hidden_bf16_int4 / kNumMsgInt4ElemPerWarp;
for (int iter_idx = 0; iter_idx < encode_num_warps; ++iter_idx) {
int num_logfmt_bytes = kNumMsgInt4ElemPerWarp * sizeof(int4);
// 原始数据的warp级编译
int warp_offset = iter_idx * kNumMsgInt4ElemPerWarp;
if constexpr(bSupportLogFMT) {
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const int& stage_idx = iter_idx % kNumStages;
// thread偏移
int thread_offset = warp_offset + lane_id * kNumSendUnrolls;
constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4; // = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes = logfmt_encode<kNumSendUnrolls>(
cpy_src_int4_ptr + warp_offset, // 等同于 x_int4
logfmt_buffers[stage_idx],
// NOTES: only the leader lane will write the result
(thread_offset % kNumInt4PerDivision == 0) ? meta_buffers + thread_offset / kNumInt4PerDivision : nullptr,
lane_id);
// 将量化后的数据写入
using vec_type = uint32_t;
UNROLLED_WARP_COPY_LL(2, lane_id, num_logfmt_bytes / sizeof(vec_type),
reinterpret_cast<vec_type *>(reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + logfmt_offset_bytes),
reinterpret_cast<vec_type *>(logfmt_buffers[stage_idx]),
ld_nc_global, st_na_global);
// 起始地址偏移
logfmt_offset_bytes += num_logfmt_bytes;
} else {
// 非量化数据的传输
UNROLLED_WARP_COPY_LL(2, lane_id, kNumMsgInt4ElemPerWarp,
reinterpret_cast<int4*>(cpy_dst_int4_ptr + warp_offset),
reinterpret_cast<const int4*>(cpy_src_int4_ptr + warp_offset),
ld_nc_global, st_na_global);
}
syncwarp();
}
// Store metadata (min/max values) for LogFMT
if constexpr (bSupportLogFMT) {
// 最终设置节点间传输的字节数
num_send_bytes = logfmt_offset_bytes;
using vec_type = uint32_t;
auto meta_buffers_ptr = reinterpret_cast<vec_type*>(meta_buffers);
auto cpy_dst_uint32_ptr = reinterpret_cast<vec_type*>(cpy_dst_int4_ptr);
for(int j = lane_id; j < kNumMetaBytes / sizeof(vec_type); j+=kWarpSize) {
*(cpy_dst_uint32_ptr + j) = meta_buffers_ptr[j];
}
}
syncwarp();
}
if (dst_p2p_ptr == 0) {
internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr, internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr,
num_ranks, dst_rank, local_expert_idx, num_ranks, dst_rank, local_expert_idx,
hidden * sizeof(hip_bfloat16)); num_send_bytes);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(x_int4);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} }
} }
...@@ -773,40 +855,136 @@ LOW_LATENCY_COMBINE_RECV: ...@@ -773,40 +855,136 @@ LOW_LATENCY_COMBINE_RECV:
// Reduce tokens with FP8 cast // Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads); // EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
}
float combined_values[kNumElemsPerInt4] = {0.0f}; // 计算需要多少个warp
#pragma unroll constexpr int num_decode_warps = hidden_bf16_int4 / (kNumRecvUnrolls * kWarpSize);
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) { // 限制thread_id
// Read from sources if (warp_id >= num_decode_warps) {
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) + return;
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot); }
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
// Reduce constexpr int kNumDivisionBytes = kNumDivisions * sizeof(float);
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id); // 每个warp内总的BF16值的数量
const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec); constexpr int kNumBF16PerWarpBytes = kWarpSize * kNumRecvUnrolls * sizeof(int4);
constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes * 10 / 16;
// 用于记录 max/min 值的 log 值
auto log_amax_buffers =
PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_buffer + i * kNumDivisionBytes); });
auto log_amin_buffers = PatternVisitor([=](const int& i) {
return reinterpret_cast<float*>(smem_buffer + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes);
});
auto cast_info_buffers = PatternVisitor([=](const int& i) {
return reinterpret_cast<int*>(smem_buffer + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes);
});
// 初始化 topk_idx 和 topk_weights
int topk_idx_by_lane = -1;
float topk_weights_by_lane = -1;
int stage_idx = 0;
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
if (lane_id < num_topk) {
topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
}
float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
int topk_idx_reg = shfl_sync(topk_idx_by_lane, i);
if (topk_idx_reg < 0)
continue;
const auto& topk_weight_reg = shfl_sync(topk_weights_by_lane, i);
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const uint8_t*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
if constexpr(bSupportLogFMT) {
// 接收到的数据位置
const uint8_t* data_buffer = rdma_buffer_type + kNumMetaBytes;
// 读取max/min数据
if(warp_id == 0) {
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin<kNumDivisions / (kWarpSize / 16), kNumSendUnrolls, kNumRecvUnrolls>(
/*meta_buffer*/rdma_buffer_type,
reinterpret_cast<int4*>(log_amax_buffers[stage_idx]),
reinterpret_cast<int4*>(log_amin_buffers[stage_idx]),
cast_info_buffers[stage_idx],
lane_id);
}
__syncthreads();
// 获取cast_info_buffers
const auto& info = cast_info_buffers[stage_idx][warp_id];
bool enable_cast = info & 1;
int num_casted_prefix = info >> 1; // 可用的
// 计算偏移(与TMA版本逻辑一致)
int warp_offset = kNumLogFMTPerWarpBytes * num_casted_prefix +
kNumBF16PerWarpBytes * (warp_id - num_casted_prefix);
int lane_offset = (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / kWarpSize * lane_id;
// 使用临时缓冲区进行归约
const uint8_t* thread_data_ptr = data_buffer + warp_offset + lane_offset;
/**
一共有kNumDivisions个max/min数据对,读取时每warp默认处理256bit的max/min,所以logfmt_check_amaxmin的kNumLanes设置为 kNumDivisions/2
保存数据时每个log_amax_buffers为float2数据类型,保存总的warpkNumDivisions / 2
实际保存数据时,每个warp保存的实际数据个数为 kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16)
实际每个warp读取的max/min的 warp_idx=kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16) / 128 = kNumRecvUnrolls * 2
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int log_amaxmin_per_warp = kNumRecvUnrolls * kWarpSize * sizeof(int4) / sizeof(hip_bfloat16) / QUANTIZATION_GROUPSIZE;
int division_idx = warp_id * log_amaxmin_per_warp + lane_id * log_amaxmin_per_warp / kWarpSize;
// 反量化
decode_and_accumulate<kNumRecvUnrolls>(
reinterpret_cast<const uint32_t*>(thread_data_ptr), // 直接使用全局内存地址
combined_values,
log_amax_buffers[stage_idx][division_idx],
log_amin_buffers[stage_idx][division_idx],
enable_cast,
topk_weight_reg);
} else {
// 接收到的数据位置
const uint8_t* data_buffer = rdma_buffer_type;
// 计算偏移
int warp_offset = kNumBF16PerWarpBytes * warp_id;
int lane_offset = kNumBF16PerWarpBytes / kWarpSize * lane_id;
// 使用临时缓冲区进行归约
const uint8_t* thread_data_ptr = data_buffer + warp_offset + lane_offset;
#pragma unroll #pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j) for (int j = 0; j < kNumRecvUnrolls; ++j) {
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i]; auto tmp_rdma_value = ld_nc_global(reinterpret_cast<const int4*>(thread_data_ptr) + j);
const auto x_bf16 = reinterpret_cast<const hip_bfloat16*>(&tmp_rdma_value);
#pragma unroll
for (int k = 0; k < kNumElemsPerInt4; ++k) {
int combined_idx = j * kNumElemsPerInt4 + k;
combined_values[combined_idx] += static_cast<float>(x_bf16[k]) * topk_weight_reg;
}
}
} }
}
// Write results // Write results,kNumRecvUnrolls==2时则写256bit的数
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values); int4 combined_int4[kNumRecvUnrolls];
auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values); auto combined_bf16 = reinterpret_cast<hip_bfloat16 *>(&combined_int4[0]);
#pragma unroll #pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j) for (int j = 0; j < kNumElemsPerInt4 * kNumRecvUnrolls; ++ j) {
combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]); combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; }
for(int j = 0; j < kNumRecvUnrolls; ++ j) {
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 +
warp_id * kWarpSize * kNumRecvUnrolls)[lane_id * kNumRecvUnrolls + j] = combined_int4[j];
} }
} }
} }
...@@ -820,6 +998,7 @@ void combine(void* combined_x, ...@@ -820,6 +998,7 @@ void combine(void* combined_x,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream, void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) { int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16; constexpr int kMaxNumWarps = 16;
...@@ -840,7 +1019,9 @@ void combine(void* combined_x, ...@@ -840,7 +1019,9 @@ void combine(void* combined_x,
#define COMBINE_LAUNCH_CASE(hidden) \ #define COMBINE_LAUNCH_CASE(hidden) \
{ \ { \
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \ auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk, kMaxNumWarps> : \
combine<false, hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \ combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \ x, topk_idx, topk_weights, src_info, layout_range, \
......
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "buffer.cuh"
#include "utils.cuh"
#include <iostream>
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
namespace deep_ep {
namespace internode_ll {
/**
 * Warp-cooperative LogFMT-10 encoder for one chunk of BF16 data.
 *
 * Each lane loads `kNumSendUnrolls` int4 values (16 BF16 elements) starting at
 * `cpy_src_int4_ptr + lane_id * kNumSendUnrolls`, strips the sign bits and joins a min/max
 * reduction over groups of `kNumLanesToReduce` lanes (128 BF16 values per group).
 *
 * When the cast condition holds, every lane packs its 16 magnitudes as 9-bit log-domain codes
 * plus 16 sign bits into 160 bits (five uint32 words) written at byte offset `lane_id * 20` of
 * `dst_buffer`. Otherwise the lane's original 32 bytes are copied unchanged to byte offset
 * `lane_id * 32` of `dst_buffer`.
 *
 * NOTE(review): the returned byte count and the per-lane layout assume `enable_cast` is uniform
 * across the warp; this relies on the semantics of `warp_reduce_and<..., true>` — confirm.
 *
 * @tparam kNumSendUnrolls  int4 values handled per lane; only 2 is supported (static assert).
 * @param cpy_src_int4_ptr  start of the source chunk, read through `ld_nc_global`.
 * @param dst_buffer        destination buffer receiving the packed or raw chunk.
 * @param shared_amaxmin    where to store this group's (amax, amin) pair; pass nullptr on lanes
 *                          that must not store it.
 * @param lane_id           lane index inside the warp.
 * @return bytes produced for the whole warp: `kWarpSize * 32 * 10 / 16` when encoded,
 *         `kWarpSize * 32` when copied raw.
 */
template <int kNumSendUnrolls>
__forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4* dst_buffer, __hip_bfloat162* shared_amaxmin, const int& lane_id) {
    EP_STATIC_ASSERT(kNumSendUnrolls == 2, "kNumSendUnrolls == 2 only");
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(__hip_bfloat16);  // 8
    constexpr float kLogThreshold = 0;
    constexpr float kMinClip = 32;  // `== log_2(2 ^ (2 ^ 5))`
    constexpr int kNumBits = 10;
    constexpr int kNumValues = 1 << (kNumBits - 1);                  // 512
    constexpr int kSendValueBytes = kNumSendUnrolls * sizeof(int4);  // 2 * 16 = 32
    constexpr int kNumElementPerInt4 = sizeof(int4) / sizeof(uint32_t);

    // Per-lane copy of the input with the sign bits cleared, viewed as uint32 and as BF16 pairs.
    int4 int4_values[kNumSendUnrolls];
    const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);
    const auto& bf162_values = reinterpret_cast<__hip_bfloat162*>(int4_values);

    // Calculate lane offset
    const auto& ld_buffer = cpy_src_int4_ptr + lane_id * kNumSendUnrolls;

    // Local amax / amin over this lane's magnitudes, plus one sign bit per BF16 element.
    auto bf162_amax = __hip_bfloat162(HIPRT_ZERO_BF16, HIPRT_ZERO_BF16);
    auto bf162_amin = __hip_bfloat162(HIPRT_INF_BF16, HIPRT_INF_BF16);
    uint32_t local_signs = 0;
    #pragma unroll
    for (int v = 0; v < kNumSendUnrolls; ++v) {
        int4 ld_int4_value = ld_nc_global(ld_buffer + v);  // vectorized load
        uint32_t* ld_u32_ptr = reinterpret_cast<uint32_t*>(&ld_int4_value);
        #pragma unroll
        for (int k = 0; k < kNumElementPerInt4; ++k) {  // kNumSendUnrolls * kNumElemsPerInt4 / 2 words in total
            // TODO: eliminate bank conflicts
            uint32_t ld_u32_value = ld_u32_ptr[k];
            int k_offset = v * kNumElementPerInt4 + k;
            // Collect the sign bits: the top bit of each BF16 half is its sign.
            local_signs |= ((ld_u32_value >> 15) & 1) << (k_offset * 2);
            local_signs |= ((ld_u32_value >> 31) & 1) << (k_offset * 2 + 1);
            // Clear the sign bits and keep the magnitudes.
            ld_u32_value &= 0x7fff7fff;
            auto ld_bf16_value = *reinterpret_cast<__hip_bfloat162*>(&ld_u32_value);
            bf162_amax = __hmax2(bf162_amax, ld_bf16_value);
            bf162_amin = __hmin2(bf162_amin, ld_bf16_value);
            uint32_values[k_offset] = ld_u32_value;
        }
    }

    // Reduce per 128 channels
    // TODO: figure out how hardware do 2-byte min/max
    auto amax = __builtin_fmaxf(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
    auto amin = __builtin_fminf(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
    // One reduction group covers 128 BF16 values: 128 * 2 / (kNumSendUnrolls * sizeof(int4)) = 8 lanes.
    constexpr static int kNumLanesToReduce = 128 * sizeof(__hip_bfloat16) / kSendValueBytes;
    amax = warp_reduce_max<kNumLanesToReduce>(amax);
    amin = warp_reduce_min<kNumLanesToReduce>(amin);

    // Write min/max through the caller-provided pointer (only lanes given a non-null pointer store).
    if (shared_amaxmin != nullptr) {
        *shared_amaxmin = __hip_bfloat162(amax, amin);
    }
    syncwarp();

    // Calculate log amin/amax float
    const auto& log_amax = __builtin_log2f(amax);
    const auto& log_amin = __builtin_fmaxf(__builtin_log2f(amin), log_amax - kMinClip);
    // Combine the per-lane cast decision through warp_reduce_and.
    const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);

    // Cast into LogFMT-10 if satisfied
    if (enable_cast) {
        constexpr int dst_buffer_step = kSendValueBytes * 10 / 16;               // 20 bytes per lane
        constexpr int kNumPackedWords = dst_buffer_step / sizeof(uint32_t);      // 5
        const auto& st_buffer = reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * dst_buffer_step);
        uint32_t st_u32_values[kNumPackedWords];

        // Distance between two adjacent quantization levels in the log2 domain.
        const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
        const auto step_inv = 1.0f / step;
        // Rounding offset, folded together with the `- log_amin / step` term.
        const auto rounding = 2.0f - __builtin_log2f((1.0f + __builtin_exp2f(step)) * 0.5f) * step_inv;
        const auto fused_rounding = rounding - log_amin * step_inv;

        // Encoded 9-bit codes, one per BF16 element of this lane.
        uint32_t encoded[kNumElemsPerInt4 * 2];
        #pragma unroll
        for (int k = 0; k < kNumElemsPerInt4; ++k) {  // 8 BF16 pairs
            const auto& fp162_fvalue = __bfloat1622float2(bf162_values[k]);
            // code = floor(max(log2(|x|) * step_inv + fused_rounding, 0)); the sign is carried separately.
            encoded[k * 2 + 0] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.x) * step_inv + fused_rounding, 0));
            encoded[k * 2 + 1] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.y) * step_inv + fused_rounding, 0));
        }

        // Pack 16 x 9-bit codes followed by 16 sign bits into five 32-bit words.
        st_u32_values[0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);
        st_u32_values[1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);
        st_u32_values[2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);
        st_u32_values[3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
        st_u32_values[4] = (encoded[14] >> 2) | (encoded[15] << 7) | (local_signs << 16);

        // Store the 160 bits as five 32-bit words. The lane's slot starts at a 20-byte stride,
        // so `st_buffer + 1` is 16-byte aligned only when lane_id % 4 == 3; a 128-bit store
        // there (and a 128-bit load from `st_u32_values + 1`) would be a misaligned access.
        #pragma unroll
        for (int w = 0; w < kNumPackedWords; ++w) {
            st_buffer[w] = st_u32_values[w];
        }
    } else {
        // Not castable: copy this lane's original (sign-preserving) data unchanged.
        using vec_type = int4;
        const auto& ld_buffer_vec = reinterpret_cast<const vec_type*>(ld_buffer);
        auto st_buffer_vec = reinterpret_cast<vec_type*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * kSendValueBytes);
        constexpr int kLoopIter = kSendValueBytes / sizeof(vec_type);
        #pragma unroll
        for (int k = 0; k < kLoopIter; ++k) {
            st_buffer_vec[k] = ld_nc_global(ld_buffer_vec + k);
        }
    }
    // Make sure every lane of the warp has finished writing its slot.
    syncwarp();

    // Byte counts for the whole warp in the raw and in the encoded case.
    constexpr int unable_cast_num_bytes = kWarpSize * kSendValueBytes;      // e.g. 64 * 32 = 2048
    constexpr int enable_cast_num_bytes = unable_cast_num_bytes * 10 / 16;  // e.g. 2048 * 10 / 16 = 1280
    // Return the number of bytes the caller has to copy out of dst_buffer.
    return enable_cast ? enable_cast_num_bytes : unable_cast_num_bytes;
}
/**
 * Converts the received (amax, amin) metadata of one message into log2 values and derives,
 * per group of lanes, whether that part of the payload was LogFMT-encoded.
 *
 * Lanes with `lane_id < kNumLanes` each read one int4 from `meta_buffer`, holding
 * `kWarpSize / 16` BF16 (amax, amin) pairs, compute `log2(amax)` and the clipped `log2(amin)`,
 * and store them as one int4 each into `shared_log_amax[lane_id]` / `shared_log_amin[lane_id]`.
 * The cast decision uses the same threshold test as `logfmt_encode`.
 *
 * Every `kNumRecvUnrolls`-th lane below `kNumLanes` then writes
 * `shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | casted`, where
 * `num_casted_prefix` counts the set "casted" bits below this lane's own bit.
 *
 * NOTE(review): `meta_buffer` is read with int4 loads, so it is presumably 16-byte aligned —
 * confirm against the caller.
 *
 * @tparam kNumLanes        number of lanes that read metadata.
 * @tparam kNumSendUnrolls  lane-group width used when AND-reducing the cast decision.
 * @tparam kNumRecvUnrolls  lane-group width used for the bitmask / prefix count.
 * @param meta_buffer       start of the packed (amax, amin) BF16 pairs.
 * @param shared_log_amax   output: log2(amax) values, one int4 (packed floats) per lane.
 * @param shared_log_amin   output: clipped log2(amin) values, one int4 (packed floats) per lane.
 * @param shared_cast_info  output: bit 0 = casted flag, upper bits = casted-prefix count.
 * @param lane_id           lane index inside the warp.
 */
template <int kNumLanes, int kNumSendUnrolls, int kNumRecvUnrolls>
__forceinline__ __device__ void logfmt_check_amaxmin(
    const uint8_t* meta_buffer, int4* shared_log_amax, int4* shared_log_amin, int* shared_cast_info, const int lane_id) {
    // Log threshold and minimum clip value (must match the encoder).
    constexpr float kLogThreshold = 0;
    constexpr float kMinClip = 32;  // `== log_2(2 ^ (2 ^ 5))`
    constexpr int kNumQuantGroupsPerWarp = kWarpSize / 16;
    using log_vec_type = int4;
    EP_STATIC_ASSERT(sizeof(log_vec_type) / sizeof(__hip_bfloat162) == kNumQuantGroupsPerWarp, "kNumQuantGroupsPerWarp == sizeof(log_vec_type) only");

    // Lanes that do not read metadata keep `true` so they do not veto the AND-reduction below.
    bool enable_cast = true;
    if (lane_id < kNumLanes) {
        // Read this lane's packed (amax, amin) pairs from meta_buffer.
        auto amaxmin4 = reinterpret_cast<const log_vec_type*>(meta_buffer)[lane_id];
        const auto& bf162_amaxmin = reinterpret_cast<__hip_bfloat162*>(&amaxmin4);

        // These arrays are reinterpreted as int4 below, so give them int4 alignment;
        // a plain float array is only guaranteed 4-byte alignment.
        alignas(int4) float log_amax[kNumQuantGroupsPerWarp];
        alignas(int4) float log_amin[kNumQuantGroupsPerWarp];
        #pragma unroll
        for (int i = 0; i < kNumQuantGroupsPerWarp; ++i) {  // sizeof(int4) / sizeof(__hip_bfloat162) pairs
            auto amax = static_cast<float>(bf162_amaxmin[i].x);
            auto amin = static_cast<float>(bf162_amaxmin[i].y);
            log_amax[i] = __builtin_log2f(amax);
            log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : __builtin_fmaxf(__builtin_log2f(amin), log_amax[i] - kMinClip);
            enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
        }

        // Publish the log values with one vector store each.
        int4 log_amax_int4 = *reinterpret_cast<int4*>(log_amax);
        int4 log_amin_int4 = *reinterpret_cast<int4*>(log_amin);
        shared_log_amax[lane_id] = log_amax_int4;
        shared_log_amin[lane_id] = log_amin_int4;
    }

    // One bit per lane group: set when the AND-reduced cast decision is true.
    const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls) : 0u;
    // Count the casted bits that sit below this lane's own bit.
    const auto& num_casted_prefix = __popc(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));

    // The first lane of every group below kNumLanes records the result.
    if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0) {
        // Lowest bit: casted flag; remaining 31 bits: num_casted_prefix.
        shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);
    }
}
/**
 * Reads one lane's 16 values from `ld_buffer` and adds `value * weight` to `accum[0..15]`.
 *
 * With `enable_cast` set, `ld_buffer` holds five 32-bit words: sixteen 9-bit codes followed by
 * sixteen sign bits. A code of 0 decodes to 0, any other code `c` decodes to
 * `exp2((c - 1) * step + log_amin)` with `step = (log_amax - log_amin) / (2^9 - 2)`.
 * Without `enable_cast`, `ld_buffer` holds eight raw BF16 pairs.
 *
 * @tparam kNumRecvUnrolls  int4 values per lane; only 2 is supported (static assert).
 * @param ld_buffer    this lane's packed or raw input words.
 * @param accum        accumulator array of 16 floats, updated in place.
 * @param log_amax     log2 of the group's maximum magnitude.
 * @param log_amin     clipped log2 of the group's minimum magnitude.
 * @param enable_cast  true when the input is LogFMT-encoded.
 * @param weight       factor applied to every decoded value before accumulation.
 */
template <int kNumRecvUnrolls>
__forceinline__ __device__ void decode_and_accumulate(
    const uint32_t* ld_buffer, float* accum, const float& log_amax, const float& log_amin,
    const bool& enable_cast, const float& weight) {
    EP_STATIC_ASSERT(kNumRecvUnrolls == 2, "kNumRecvUnrolls == 2 only");

    // Raw path: every 32-bit word is a BF16 pair that is accumulated directly.
    if (not enable_cast) {
        constexpr int kNumPairs = kNumRecvUnrolls * sizeof(int4) / sizeof(uint32_t);
        #pragma unroll
        for (int pair_idx = 0; pair_idx < kNumPairs; ++pair_idx) {
            const auto pair = *reinterpret_cast<const __hip_bfloat162*>(ld_buffer + pair_idx);
            accum[pair_idx * 2 + 0] += static_cast<float>(pair.x) * weight;
            accum[pair_idx * 2 + 1] += static_cast<float>(pair.y) * weight;
        }
        return;
    }

    // Encoded path.
    constexpr int kNumBits = 10;
    constexpr int kNumValues = 1 << (kNumBits - 1);
    constexpr uint32_t kCodeMask = 0x1ff;  // one 9-bit code
    const float step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);

    // Maps a 9-bit code and its sign bit back to a float.
    auto dequantize = [=](const uint32_t& code, const uint32_t& negative) {
        const auto magnitude = code == 0 ? .0f : __builtin_exp2f((code - 1) * step + log_amin);
        return negative ? -magnitude : magnitude;
    };

    // Re-align the bit stream so that window k starts at code 3 * k (27 bits of codes per window).
    uint32_t windows[6];
    windows[0] = ld_buffer[0];
    #pragma unroll
    for (int k = 1; k < 5; ++k) {
        const int shift = k * 5;
        windows[k] = (ld_buffer[k - 1] >> (32 - shift)) | (ld_buffer[k] << shift);
    }
    windows[5] = ld_buffer[4] >> 7;

    // The sign bits occupy the upper half of the last word.
    const uint32_t sign_bits = ld_buffer[4] >> 16;

    // Codes 0..14: three per window.
    #pragma unroll
    for (int k = 0; k < 5; ++k) {
        #pragma unroll
        for (int field = 0; field < 3; ++field) {
            const int elem = k * 3 + field;
            accum[elem] += dequantize((windows[k] >> (field * 9)) & kCodeMask, (sign_bits >> elem) & 1) * weight;
        }
    }
    // Code 15 sits alone in the last window.
    accum[15] += dequantize(windows[5] & kCodeMask, (sign_bits >> 15) & 1) * weight;
}
} // namespace internode_ll
} // namespace deep_ep
...@@ -72,9 +72,9 @@ __device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWa ...@@ -72,9 +72,9 @@ __device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWa
return __shfl_xor(val, laneMask, width); return __shfl_xor(val, laneMask, width);
} }
__device__ __forceinline__ int template <typename T>
shfl_sync(const int val, int srcLane = 0, int width = kWarpSize, __device__ __forceinline__ T shfl_sync(const T val, int srcLane = 0, int width = kWarpSize,
uint64_t shfl_sync_mask = kFullWarpMask) { // Let compiler deduce type uint64_t shfl_sync_mask = kFullWarpMask) { // Let compiler deduce type
return __shfl(val, srcLane, width); return __shfl(val, srcLane, width);
} }
...@@ -115,6 +115,15 @@ template <> struct VecInt<16> { ...@@ -115,6 +115,15 @@ template <> struct VecInt<16> {
using vec_t = native_int4; using vec_t = native_int4;
}; };
/// Adapter that gives a callable subscript syntax: `visitor[i]` evaluates `func(i)`.
template <typename FuncT>
struct PatternVisitor {
    // The wrapped callable; invoked once per subscript.
    FuncT func;

    // Stores the given callable in `func`.
    __device__ __host__ explicit PatternVisitor(FuncT&& callable)
        : func(std::forward<FuncT>(callable)) {}

    // Invokes the wrapped callable with `idx` and returns its result by value.
    __device__ __host__ auto operator[](const uint32_t& idx) {
        return func(idx);
    }
};
__device__ __forceinline__ void trap() { __device__ __forceinline__ void trap() {
abort(); abort();
} }
......
...@@ -923,7 +923,8 @@ class Buffer: ...@@ -923,7 +923,8 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, zero_copy: bool = False, async_finish: bool = False, handle: tuple, use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]: Tuple[torch.Tensor, EventOverlap, Callable]:
...@@ -944,6 +945,7 @@ class Buffer: ...@@ -944,6 +945,7 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor. tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function. handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`. with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
...@@ -964,7 +966,7 @@ class Buffer: ...@@ -964,7 +966,7 @@ class Buffer:
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats, combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
zero_copy, async_finish, return_recv_hook, out) use_logfmt, zero_copy, async_finish, return_recv_hook, out)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
......
...@@ -42,6 +42,7 @@ def test_main(num_tokens: int, ...@@ -42,6 +42,7 @@ def test_main(num_tokens: int,
num_ranks: int, num_ranks: int,
group: dist.ProcessGroup, group: dist.ProcessGroup,
buffer: deep_ep.Buffer, buffer: deep_ep.Buffer,
use_logfmt: bool = False,
seed: int = 0): seed: int = 0):
torch.manual_seed(seed + rank) torch.manual_seed(seed + rank)
random.seed(seed + rank) random.seed(seed + rank)
...@@ -56,10 +57,12 @@ def test_main(num_tokens: int, ...@@ -56,10 +57,12 @@ def test_main(num_tokens: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x] x_list = [x]
# # NOTES: the last one is for performance testing for _ in range(4 if use_logfmt else 0):
# # Most of the values in the perf case is lower than the threshold, casting most channels # NOTES: make more LogFMT casts and also with some BF16
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1 x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# x_list = [x_rand] # NOTES: the last one is for performance testing
# Most of the values in the perf case are lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1] topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
...@@ -79,7 +82,7 @@ def test_main(num_tokens: int, ...@@ -79,7 +82,7 @@ def test_main(num_tokens: int,
# Check dispatch correctness # Check dispatch correctness
do_check = True do_check = True
hash_value, num_times = 0, 0 hash_value, num_times = 0, 0
for current_x in x_list: for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True): for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2 for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0 dispatch_use_quant = quant_type > 0
...@@ -152,7 +155,7 @@ def test_main(num_tokens: int, ...@@ -152,7 +155,7 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness # Check combine correctness
for zero_copy in (False, True): for zero_copy in (False, ) if use_logfmt else (False, True, ):
if zero_copy: if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
...@@ -160,6 +163,7 @@ def test_main(num_tokens: int, ...@@ -160,6 +163,7 @@ def test_main(num_tokens: int,
topk_idx, topk_idx,
topk_weights, topk_weights,
handle, handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook, async_finish=not return_recv_hook,
zero_copy=zero_copy, zero_copy=zero_copy,
return_recv_hook=return_recv_hook, return_recv_hook=return_recv_hook,
...@@ -172,6 +176,10 @@ def test_main(num_tokens: int, ...@@ -172,6 +176,10 @@ def test_main(num_tokens: int,
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}' assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x) hash_value ^= hash_tensor(combined_x)
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
# noinspection PyShadowingNames # noinspection PyShadowingNames
def large_gemm_with_hook(hook): def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float) mat_0 = torch.randn((8192, 8192), dtype=torch.float)
...@@ -190,6 +198,7 @@ def test_main(num_tokens: int, ...@@ -190,6 +198,7 @@ def test_main(num_tokens: int,
topk_idx, topk_idx,
topk_weights, topk_weights,
handle, handle,
use_logfmt=use_logfmt,
return_recv_hook=return_recv_hook) return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None large_gemm_with_hook(hook) if return_recv_hook else None
...@@ -251,6 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -251,6 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks, num_ranks,
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt,
seed=1) seed=1)
do_pressure_test = args.pressure_test do_pressure_test = args.pressure_test
...@@ -265,6 +275,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -265,6 +275,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks, num_ranks,
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt,
seed=seed) seed=seed)
for _ in range(20): for _ in range(20):
assert test_main(num_tokens, assert test_main(num_tokens,
...@@ -275,6 +286,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -275,6 +286,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks, num_ranks,
group, group,
buffer, buffer,
use_logfmt=args.use_logfmt,
seed=seed) == ref_hash, f'Error: seed={seed}' seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group # Destroy the buffer runtime and communication group
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment