"vscode:/vscode.git/clone" did not exist on "dfc9ef1c3d730b005c8a5041fe5a7dc66b8af213"
Commit 75b00cfb authored by lijian6's avatar lijian6
Browse files

Merge branch 'logfmt_master' into 'main'

低延迟combine支持10bit量化代码

See merge request dcutoolkit/deeplearing/DeepEP!21
parents dbf9fd61 17d9c844
......@@ -136,7 +136,7 @@ struct LowLatencyLayout {
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) {
const int num_scales = quant_group_size == 0 ? 4 : hidden / QUANTIZATION_GROUPSIZE; // 应该是1,但是代码中为了满足int4对齐
const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
......@@ -148,11 +148,11 @@ struct LowLatencyLayout {
// transformation
EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast<size_t>(hidden));
size_t num_bytes_per_dispatch_msg =
sizeof(int4) +
std::max(hidden * sizeof(hip_bfloat16), hidden + num_scales * sizeof(float));
sizeof(int4) + std::max(hidden * sizeof(hip_bfloat16), hidden +
(quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16);
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162);
// Send buffer
size_t dispatch_send_buffer_bytes =
......
......@@ -1413,6 +1413,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode);
......@@ -1482,6 +1483,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_logfmt,
workspace, num_device_sms, launch_stream,
phases, zero_copy);
};
......
......@@ -185,6 +185,7 @@ public:
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
......
......@@ -159,6 +159,7 @@ void combine(void* combined_x,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
......
......@@ -9,6 +9,7 @@
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
#include "internode_ll_logfmt.cuh"
namespace deep_ep {
......@@ -612,7 +613,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#undef DISPATCH_LAUNCH_CASE
}
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
......@@ -643,7 +644,24 @@ combine(void* combined_x,
// Message package
EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
/////////////// LogFMT使用 ///////////////
constexpr int bSupportLogFMT = kUseLogFMT && hidden_bf16_int4 % (kWarpSize * 2) == 0;
constexpr int kNumSendUnrolls = bSupportLogFMT ? 2 : 1;
constexpr int kNumRecvUnrolls = bSupportLogFMT ? 2 : 1;
constexpr int kNumMsgInt4ElemPerWarp = kWarpSize * kNumSendUnrolls; // 每个warp发送的int4元素数据量,即每个warp发送 kNumMsgInt4ElemPerWarp*sizeof(int4)/sizeof(bfloat16)
EP_STATIC_ASSERT(hidden_bf16_int4 % (kNumSendUnrolls * kWarpSize) == 0, "Invalid hidden");
EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors");
constexpr int kNumDivisions = kHidden / QUANTIZATION_GROUPSIZE;
constexpr int kNumMetaBytes = kNumDivisions * sizeof(__hip_bfloat162); // 用于记录数据的最大最小值
constexpr int kNumSendLogFMTBytes = kNumMsgInt4ElemPerWarp * sizeof(int4);
constexpr int kNumStages = 1; // 使用kNumStages>1,则需要的LDS大于64KB
constexpr int kLogFMTShmemSize = kMaxNumWarps * (kNumStages * kNumSendLogFMTBytes + kNumMetaBytes);
__shared__ uint8_t smem_buffer[kLogFMTShmemSize];
/////////////////////////////////////////////
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16) + kNumMetaBytes;
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 初始化用于细粒度warp间同步的计数器数组
......@@ -683,6 +701,12 @@ combine(void* combined_x,
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = reinterpret_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// 用于logfmt的LDS
auto smem_ptr = smem_buffer + warp_id * (kNumStages * kNumSendLogFMTBytes + kNumMetaBytes);
// 存储logfmt的起始地址,并根据stage_idx进行索引块
auto logfmt_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast<int4*>(smem_ptr + i * kNumSendLogFMTBytes); });
// 存储logfmt的最大最小值
auto meta_buffers = bSupportLogFMT ? reinterpret_cast<__hip_bfloat162*>(smem_ptr + kNumStages * kNumSendLogFMTBytes) : nullptr;
// Unpack layout
int offset, num_tokens_to_send;
......@@ -699,20 +723,78 @@ combine(void* combined_x,
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
// 采用logfmt或者直接拷贝
uint64_t dst_p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
int num_send_bytes = hidden * sizeof(hip_bfloat16);
if (not zero_copy or dst_p2p_ptr != 0) {
// Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr`
const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast<int4*>(buf_ptr) : x_int4;
const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast<int4*>(buf_ptr): reinterpret_cast<int4*>(dst_p2p_ptr);
// 设置数据的真实偏移量
int logfmt_offset_bytes = kNumMetaBytes;
// 进入循环,逐步拷贝数据
constexpr int encode_num_warps = hidden_bf16_int4 / kNumMsgInt4ElemPerWarp;
for (int iter_idx = 0; iter_idx < encode_num_warps; ++iter_idx) {
int num_logfmt_bytes = kNumMsgInt4ElemPerWarp * sizeof(int4);
// 原始数据的warp级编译
int warp_offset = iter_idx * kNumMsgInt4ElemPerWarp;
if constexpr(bSupportLogFMT) {
// 采用 寄存器->lds->global 的流水线方式, 量化后拷贝到buf_ptr中
const int& stage_idx = iter_idx % kNumStages;
// thread偏移
int thread_offset = warp_offset + lane_id * kNumSendUnrolls;
constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4; // = 128/(sizeof(int4) / sizeof(hip_bfloat16)) = 128/(16/2)=16
num_logfmt_bytes = logfmt_encode<kNumSendUnrolls>(
cpy_src_int4_ptr + warp_offset, // 等同于 x_int4
logfmt_buffers[stage_idx],
// NOTES: only the leader lane will write the result
(thread_offset % kNumInt4PerDivision == 0) ? meta_buffers + thread_offset / kNumInt4PerDivision : nullptr,
lane_id);
// 将量化后的数据写入
using vec_type = uint32_t;
UNROLLED_WARP_COPY_LL(2, lane_id, num_logfmt_bytes / sizeof(vec_type),
reinterpret_cast<vec_type *>(reinterpret_cast<uint8_t*>(cpy_dst_int4_ptr) + logfmt_offset_bytes),
reinterpret_cast<vec_type *>(logfmt_buffers[stage_idx]),
ld_nc_global, st_na_global);
// 起始地址偏移
logfmt_offset_bytes += num_logfmt_bytes;
} else {
// 非量化数据的传输
UNROLLED_WARP_COPY_LL(2, lane_id, kNumMsgInt4ElemPerWarp,
reinterpret_cast<int4*>(cpy_dst_int4_ptr + warp_offset),
reinterpret_cast<const int4*>(cpy_src_int4_ptr + warp_offset),
ld_nc_global, st_na_global);
}
syncwarp();
}
// Store metadata (min/max values) for LogFMT
if constexpr (bSupportLogFMT) {
// 最终设置节点间传输的字节数
num_send_bytes = logfmt_offset_bytes;
using vec_type = uint32_t;
auto meta_buffers_ptr = reinterpret_cast<vec_type*>(meta_buffers);
auto cpy_dst_uint32_ptr = reinterpret_cast<vec_type*>(cpy_dst_int4_ptr);
for(int j = lane_id; j < kNumMetaBytes / sizeof(vec_type); j+=kWarpSize) {
*(cpy_dst_uint32_ptr + j) = meta_buffers_ptr[j];
}
}
syncwarp();
}
if (dst_p2p_ptr == 0) {
internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr,
num_ranks, dst_rank, local_expert_idx,
hidden * sizeof(hip_bfloat16));
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(x_int4);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
num_send_bytes);
}
}
......@@ -773,40 +855,136 @@ LOW_LATENCY_COMBINE_RECV:
// Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
// 计算需要多少个warp
constexpr int num_decode_warps = hidden_bf16_int4 / (kNumRecvUnrolls * kWarpSize);
// 限制thread_id
if (warp_id >= num_decode_warps) {
return;
}
// 每128个数据记录一个max/min值,即该数为总的max/min值数量
constexpr int kNumDivisionBytes = kNumDivisions * sizeof(float);
// 每个warp内总的BF16值的数量
constexpr int kNumBF16PerWarpBytes = kWarpSize * kNumRecvUnrolls * sizeof(int4);
constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes * 10 / 16;
// 用于记录 max/min 值的 log 值
auto log_amax_buffers =
PatternVisitor([=](const int& i) { return reinterpret_cast<float*>(smem_buffer + i * kNumDivisionBytes); });
auto log_amin_buffers = PatternVisitor([=](const int& i) {
return reinterpret_cast<float*>(smem_buffer + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes);
});
auto cast_info_buffers = PatternVisitor([=](const int& i) {
return reinterpret_cast<int*>(smem_buffer + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes);
});
// 初始化 topk_idx 和 topk_weights
int topk_idx_by_lane = -1;
float topk_weights_by_lane = -1;
int stage_idx = 0;
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
if (lane_id < num_topk) {
topk_idx_by_lane = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id));
topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id);
}
float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
int topk_idx_reg = shfl_sync(topk_idx_by_lane, i);
if (topk_idx_reg < 0)
continue;
const auto& topk_weight_reg = shfl_sync(topk_weights_by_lane, i);
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const uint8_t*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
if constexpr(bSupportLogFMT) {
// 接收到的数据位置
const uint8_t* data_buffer = rdma_buffer_type + kNumMetaBytes;
// 读取max/min数据
if(warp_id == 0) {
// 因为每个warp能处理数据量为 kWarpSize*sizeof(int4)/sizeof(bfloat16) * kNumSendUnrolls
// 即不考虑kNumSendUnrolls,一共 kWarpSize*sizeof(int4)/sizeof(bfloat16)/128 组, 代入参数 = kWarpSize / 16 个warp,nv上为2,dcu上为4
logfmt_check_amaxmin<kNumDivisions / (kWarpSize / 16), kNumSendUnrolls, kNumRecvUnrolls>(
/*meta_buffer*/rdma_buffer_type,
reinterpret_cast<int4*>(log_amax_buffers[stage_idx]),
reinterpret_cast<int4*>(log_amin_buffers[stage_idx]),
cast_info_buffers[stage_idx],
lane_id);
}
float combined_values[kNumElemsPerInt4] = {0.0f};
__syncthreads();
// 获取cast_info_buffers
const auto& info = cast_info_buffers[stage_idx][warp_id];
bool enable_cast = info & 1;
int num_casted_prefix = info >> 1; // 可用的
// 计算偏移(与TMA版本逻辑一致)
int warp_offset = kNumLogFMTPerWarpBytes * num_casted_prefix +
kNumBF16PerWarpBytes * (warp_id - num_casted_prefix);
int lane_offset = (enable_cast ? kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / kWarpSize * lane_id;
// 使用临时缓冲区进行归约
const uint8_t* thread_data_ptr = data_buffer + warp_offset + lane_offset;
/**
一共有kNumDivisions个max/min数据对,读取时每warp默认处理256bit的max/min,所以logfmt_check_amaxmin的kNumLanes设置为 kNumDivisions/2
保存数据时每个log_amax_buffers为float2数据类型,保存总的warpkNumDivisions / 2
实际保存数据时,每个warp保存的实际数据个数为 kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16)
实际每个warp读取的max/min的 warp_idx=kWarpSize*kNumRecvUnrolls*sizeof(int4)/sizeof(hip_bfloat16) / 128 = kNumRecvUnrolls * 2
具体的lane_id处理的数据量为 warp_idx / kWarpSize
*/
int log_amaxmin_per_warp = kNumRecvUnrolls * kWarpSize * sizeof(int4) / sizeof(hip_bfloat16) / QUANTIZATION_GROUPSIZE;
int division_idx = warp_id * log_amaxmin_per_warp + lane_id * log_amaxmin_per_warp / kWarpSize;
// 反量化
decode_and_accumulate<kNumRecvUnrolls>(
reinterpret_cast<const uint32_t*>(thread_data_ptr), // 直接使用全局内存地址
combined_values,
log_amax_buffers[stage_idx][division_idx],
log_amin_buffers[stage_idx][division_idx],
enable_cast,
topk_weight_reg);
} else {
// 接收到的数据位置
const uint8_t* data_buffer = rdma_buffer_type;
// 计算偏移
int warp_offset = kNumBF16PerWarpBytes * warp_id;
int lane_offset = kNumBF16PerWarpBytes / kWarpSize * lane_id;
// 使用临时缓冲区进行归约
const uint8_t* thread_data_ptr = data_buffer + warp_offset + lane_offset;
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
for (int j = 0; j < kNumRecvUnrolls; ++j) {
auto tmp_rdma_value = ld_nc_global(reinterpret_cast<const int4*>(thread_data_ptr) + j);
const auto x_bf16 = reinterpret_cast<const hip_bfloat16*>(&tmp_rdma_value);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
for (int k = 0; k < kNumElemsPerInt4; ++k) {
int combined_idx = j * kNumElemsPerInt4 + k;
combined_values[combined_idx] += static_cast<float>(x_bf16[k]) * topk_weight_reg;
}
}
}
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
// Write results,kNumRecvUnrolls==2时则写256bit的数
int4 combined_int4[kNumRecvUnrolls];
auto combined_bf16 = reinterpret_cast<hip_bfloat16 *>(&combined_int4[0]);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
for (int j = 0; j < kNumElemsPerInt4 * kNumRecvUnrolls; ++ j) {
combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
for(int j = 0; j < kNumRecvUnrolls; ++ j) {
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4 +
warp_id * kWarpSize * kNumRecvUnrolls)[lane_id * kNumRecvUnrolls + j] = combined_int4[j];
}
}
}
......@@ -820,6 +998,7 @@ void combine(void* combined_x,
int64_t* next_clean, int num_next_clean_int,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
......@@ -840,7 +1019,9 @@ void combine(void* combined_x,
#define COMBINE_LAUNCH_CASE(hidden) \
{ \
auto combine_func = combine<hidden, kNumMaxTopk, kMaxNumWarps>; \
auto combine_func = use_logfmt ? \
combine<true, hidden, kNumMaxTopk, kMaxNumWarps> : \
combine<false, hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
......
#pragma once
#include "configs.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "buffer.cuh"
#include "utils.cuh"
#include <iostream>
#include "hip/hip_runtime.h"
#include "shmem_wrapper.cuh"
namespace deep_ep {
namespace internode_ll {
/**
 * Warp-cooperative LogFMT-10 encoder for one warp-sized chunk of BF16 data.
 *
 * Every lane loads `kNumSendUnrolls` consecutive `int4` values (16 BF16 elements) starting at
 * `cpy_src_int4_ptr + lane_id * kNumSendUnrolls`, strips the sign bits and contributes to a
 * min/max reduction of the magnitudes. If the cast condition holds, each lane packs its 16
 * elements into a 160-bit record (16 x 9-bit magnitude codes followed by 16 sign bits) at
 * `dst_buffer + lane_id * 20` bytes; otherwise the lane copies its 32 source bytes unchanged to
 * `dst_buffer + lane_id * 32` bytes.
 *
 * @tparam kNumSendUnrolls  number of `int4` values handled per lane; only 2 is supported.
 * @param cpy_src_int4_ptr  start of the warp's source chunk (read with `ld_nc_global`).
 * @param dst_buffer        start of the warp's output staging buffer (the visible caller passes a
 *                          per-warp slice of a `__shared__` array).
 * @param shared_amaxmin    if non-null, receives the reduced `(amax, amin)` pair of this lane's
 *                          group as a `__hip_bfloat162`; callers pass non-null only for the leader
 *                          lane of each group.
 * @param lane_id           lane index of the calling thread inside its warp.
 * @return `kWarpSize * kSendValueBytes * 10 / 16` when the data was cast, otherwise
 *         `kWarpSize * kSendValueBytes`.
 *
 * NOTE(review): the visible caller uses the return value as the byte count for the whole warp,
 * which requires `enable_cast` to be identical on all lanes. That depends on the semantics of
 * `warp_reduce_and<kNumLanesToReduce, true>`, which is defined elsewhere -- confirm.
 */
template <int kNumSendUnrolls>
__forceinline__ __device__ int logfmt_encode(const int4* cpy_src_int4_ptr, int4* dst_buffer, __hip_bfloat162* shared_amaxmin, const int& lane_id) {
    EP_STATIC_ASSERT(kNumSendUnrolls == 2, "kNumSendUnrolls == 2 only");
    constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(__hip_bfloat16);  // 8 BF16 per int4
    constexpr float kLogThreshold = 0;
    constexpr float kMinClip = 32;  // `== log_2(2 ^ (2 ^ 5))`
    constexpr int kNumBits = 10;
    constexpr int kNumValues = 1 << (kNumBits - 1);  // 512 magnitude codes (9 bits)
    constexpr int kSendValueBytes = kNumSendUnrolls * sizeof(int4);  // 2 * 16 = 32 bytes per lane
    constexpr int kNumElementPerInt4 = sizeof(int4) / sizeof(uint32_t);  // 4 words per int4

    // Per-lane copy of the source values with the sign bits cleared, viewed as words and as BF16 pairs
    int4 int4_values[kNumSendUnrolls];
    const auto& uint32_values = reinterpret_cast<uint32_t*>(int4_values);
    const auto& bf162_values = reinterpret_cast<__hip_bfloat162*>(int4_values);

    // Calculate lane offset
    const auto& ld_buffer = cpy_src_int4_ptr + lane_id * kNumSendUnrolls;

    // Local amax/amin of the magnitudes, plus one sign bit per BF16 element
    auto bf162_amax = __hip_bfloat162(HIPRT_ZERO_BF16, HIPRT_ZERO_BF16);
    auto bf162_amin = __hip_bfloat162(HIPRT_INF_BF16, HIPRT_INF_BF16);
    uint32_t local_signs = 0;
    #pragma unroll
    for (int v = 0; v < kNumSendUnrolls; ++v) {
        int4 ld_int4_value = ld_nc_global(ld_buffer + v);  // vectorized load
        uint32_t* ld_u32_ptr = reinterpret_cast<uint32_t*>(&ld_int4_value);
        #pragma unroll
        for (int k = 0; k < kNumElementPerInt4; ++k) {
            // TODO: eliminate bank conflicts
            uint32_t ld_u32_value = ld_u32_ptr[k];
            int k_offset = v * kNumElementPerInt4 + k;
            // Each word holds two BF16 values; bit 15 and bit 31 are their sign bits.
            // Sign of element `e` ends up in bit `e` of `local_signs`.
            local_signs |= ((ld_u32_value >> 15) & 1) << (k_offset * 2);
            local_signs |= ((ld_u32_value >> 31) & 1) << (k_offset * 2 + 1);
            // Clear both sign bits, keeping only the magnitudes
            ld_u32_value &= 0x7fff7fff;
            auto ld_bf16_value = *reinterpret_cast<__hip_bfloat162*>(&ld_u32_value);
            bf162_amax = __hmax2(bf162_amax, ld_bf16_value);
            bf162_amin = __hmin2(bf162_amin, ld_bf16_value);
            uint32_values[k_offset] = ld_u32_value;
        }
    }

    // Reduce per 128 channels
    // TODO: figure out how hardware do 2-byte min/max
    auto amax = __builtin_fmaxf(static_cast<float>(bf162_amax.x), static_cast<float>(bf162_amax.y));
    auto amin = __builtin_fminf(static_cast<float>(bf162_amin.x), static_cast<float>(bf162_amin.y));
    // 128 BF16 values occupy 256 bytes, i.e. 256 / kSendValueBytes = 8 lanes
    constexpr static int kNumLanesToReduce = 128 * sizeof(__hip_bfloat16) / kSendValueBytes;
    amax = warp_reduce_max<kNumLanesToReduce>(amax);
    amin = warp_reduce_min<kNumLanesToReduce>(amin);

    // Write min/max into the caller-provided metadata slot (leader lanes only)
    if (shared_amaxmin != nullptr) {
        *shared_amaxmin = __hip_bfloat162(amax, amin);
    }
    syncwarp();

    // Calculate log amin/amax float; the lower bound is clipped to at most `kMinClip` below the upper bound
    const auto& log_amax = __builtin_log2f(amax);
    const auto& log_amin = __builtin_fmaxf(__builtin_log2f(amin), log_amax - kMinClip);
    // Combine the per-lane cast predicate across lanes (see NOTE(review) in the header comment)
    const bool& enable_cast = warp_reduce_and<kNumLanesToReduce, true>(log_amax < kLogThreshold and log_amin < log_amax);

    // Cast into LogFMT-10 if satisfied
    if (enable_cast) {
        constexpr int dst_buffer_step = kSendValueBytes * 10 / 16;           // 20 bytes per lane
        constexpr int kNumPackedWords = dst_buffer_step / sizeof(uint32_t);  // 5 words per lane
        const auto& st_buffer = reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * dst_buffer_step);
        uint32_t st_u32_values[kNumPackedWords];

        // Distance in the log2 domain between two adjacent magnitude codes
        const auto step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);
        const auto step_inv = 1.0f / step;

        // Rounding offset, fused with the `-log_amin / step` shift so each element needs one multiply-add
        const auto rounding = 2.0f - __builtin_log2f((1.0f + __builtin_exp2f(step)) * 0.5f) * step_inv;
        const auto fused_rounding = rounding - log_amin * step_inv;

        // code = floor(max(log2(|x|) * step_inv + fused_rounding, 0)).
        // An exact zero has log2 == -inf and therefore maps to code 0.
        uint32_t encoded[kNumElemsPerInt4 * 2];
        #pragma unroll
        for (int k = 0; k < kNumElemsPerInt4; ++k) {
            const auto& fp162_fvalue = __bfloat1622float2(bf162_values[k]);
            encoded[k * 2 + 0] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.x) * step_inv + fused_rounding, 0));
            encoded[k * 2 + 1] = __float2uint_rd(__builtin_fmaxf(__builtin_log2f(fp162_fvalue.y) * step_inv + fused_rounding, 0));
        }

        // Pack 16 x 9-bit codes back to back (144 bits), then the 16 sign bits in the top half of word 4
        st_u32_values[0] = (encoded[0] >> 0) | (encoded[1] << 9) | (encoded[2] << 18) | (encoded[3] << 27);
        st_u32_values[1] = (encoded[3] >> 5) | (encoded[4] << 4) | (encoded[5] << 13) | (encoded[6] << 22) | (encoded[7] << 31);
        st_u32_values[2] = (encoded[7] >> 1) | (encoded[8] << 8) | (encoded[9] << 17) | (encoded[10] << 26);
        st_u32_values[3] = (encoded[10] >> 6) | (encoded[11] << 3) | (encoded[12] << 12) | (encoded[13] << 21) | (encoded[14] << 30);
        st_u32_values[4] = (encoded[14] >> 2) | (encoded[15] << 7) | (local_signs << 16);

        // Store the 160-bit record word by word. The record starts at byte `lane_id * 20`, which is
        // only guaranteed to be 4-byte aligned, so a 16-byte vector access through an `int4*`
        // (on either the destination or the local word array) would be misaligned.
        #pragma unroll
        for (int w = 0; w < kNumPackedWords; ++w) {
            st_buffer[w] = st_u32_values[w];
        }
    } else {
        // No cast: copy this lane's original 32 bytes (sign bits included) straight from the source
        using vec_type = int4;
        const auto& ld_buffer_vec = reinterpret_cast<const vec_type*>(ld_buffer);
        auto st_buffer_vec = reinterpret_cast<vec_type*>(reinterpret_cast<uint8_t*>(dst_buffer) + lane_id * kSendValueBytes);
        constexpr int kLoopIter = kSendValueBytes / sizeof(vec_type);
        #pragma unroll
        for (int k = 0; k < kLoopIter; ++k) {
            st_buffer_vec[k] = ld_nc_global(ld_buffer_vec + k);
        }
    }

    // Make sure every lane of the warp has finished writing its part
    syncwarp();

    // Byte counts of the warp's output with and without the cast
    constexpr int unable_cast_num_bytes = kWarpSize * kSendValueBytes;      // 2048 when kWarpSize == 64
    constexpr int enable_cast_num_bytes = unable_cast_num_bytes * 10 / 16;  // 1280 when kWarpSize == 64

    // Return the number of bytes produced in `dst_buffer`
    return enable_cast ? enable_cast_num_bytes : unable_cast_num_bytes;
}
/**
 * Receiver-side pre-pass that turns the `(amax, amin)` metadata of one received slot into log2
 * bounds and per-chunk cast information.
 *
 * Lanes with `lane_id < kNumLanes` each read one `int4` from `meta_buffer`, i.e.
 * `kNumQuantGroupsPerWarp` `__hip_bfloat162` pairs with `.x = amax` and `.y = amin`. They compute
 * `log2(amax)` and the clipped `log2(amin)` for every pair and store both as one `int4` into
 * `shared_log_amax[lane_id]` / `shared_log_amin[lane_id]`. The per-lane cast predicates are then
 * combined with warp reductions, and lanes with `lane_id % kNumRecvUnrolls == 0` write
 * `shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | cast_flag`.
 *
 * The warp reductions are executed by every lane of the calling warp; lanes at or above
 * `kNumLanes` contribute `enable_cast == true`.
 *
 * @tparam kNumLanes        number of lanes that own a metadata vector.
 * @tparam kNumSendUnrolls  group width passed to `warp_reduce_and`.
 * @tparam kNumRecvUnrolls  number of lanes that share one `shared_cast_info` entry.
 * @param meta_buffer       start of the slot's metadata (read as `int4` vectors).
 * @param shared_log_amax   output array for the log2 upper bounds.
 * @param shared_log_amin   output array for the log2 lower bounds.
 * @param shared_cast_info  output array; bit 0 is the cast flag, the remaining bits hold the
 *                          number of preceding entries whose cast flag is set.
 * @param lane_id           lane index of the calling thread inside its warp.
 *
 * NOTE(review): `warp_reduce_and` / `warp_reduce_or` are defined elsewhere; the description of the
 * prefix count assumes they reduce within groups (and, with the `true` flag, across groups) the
 * way their names suggest -- confirm.
 * NOTE(review): `1u << (lane_id / kNumRecvUnrolls)` is a 32-bit shift and stays in range only
 * while `kWarpSize / kNumRecvUnrolls <= 32` -- confirm for every instantiation.
 */
template <int kNumLanes, int kNumSendUnrolls, int kNumRecvUnrolls>
__forceinline__ __device__ void logfmt_check_amaxmin(
    const uint8_t* meta_buffer, int4* shared_log_amax, int4* shared_log_amin, int* shared_cast_info, const int lane_id) {
    // Cast threshold and maximum log2 distance between the bounds (same constants as the encoder)
    constexpr float kLogThreshold = 0;
    constexpr float kMinClip = 32;  // `== log_2(2 ^ (2 ^ 5))`
    constexpr int kNumQuantGroupsPerWarp = kWarpSize / 16;
    using log_vec_type = int4;
    EP_STATIC_ASSERT(sizeof(log_vec_type) / sizeof(__hip_bfloat162) == kNumQuantGroupsPerWarp, "kNumQuantGroupsPerWarp == sizeof(log_vec_type) only");

    // Lanes without metadata keep the neutral value for the AND reduction below
    bool enable_cast = true;

    if (lane_id < kNumLanes) {
        // Load this lane's vector of (amax, amin) pairs
        auto amaxmin4 = reinterpret_cast<const log_vec_type*>(meta_buffer)[lane_id];
        const auto& bf162_amaxmin = reinterpret_cast<__hip_bfloat162*>(&amaxmin4);

        // The arrays are re-read through `int4*` below, so they must carry `int4` alignment;
        // a plain `float[]` is only 4-byte aligned.
        alignas(log_vec_type) float log_amax[kNumQuantGroupsPerWarp];
        alignas(log_vec_type) float log_amin[kNumQuantGroupsPerWarp];

        #pragma unroll
        for (int i = 0; i < kNumQuantGroupsPerWarp; ++i) {
            auto amax = static_cast<float>(bf162_amaxmin[i].x);
            auto amin = static_cast<float>(bf162_amaxmin[i].y);
            log_amax[i] = __builtin_log2f(amax);
            // A zero minimum is clipped directly instead of going through log2(0)
            log_amin[i] = amin == 0 ? log_amax[i] - kMinClip : __builtin_fmaxf(__builtin_log2f(amin), log_amax[i] - kMinClip);
            enable_cast = enable_cast and log_amax[i] < kLogThreshold and log_amin[i] < log_amax[i];
        }

        // Publish the log2 bounds with one vector store per array
        int4 log_amax_int4 = *reinterpret_cast<int4*>(log_amax);
        int4 log_amin_int4 = *reinterpret_cast<int4*>(log_amin);
        shared_log_amax[lane_id] = log_amax_int4;
        shared_log_amin[lane_id] = log_amin_int4;
    }

    // One bit per cast-info entry: set when the reduced predicate of this lane's group is true
    const auto& casted = warp_reduce_and<kNumSendUnrolls>(enable_cast) ? 1u << (lane_id / kNumRecvUnrolls) : 0u;
    // Number of set bits below this lane's own entry in the OR-reduced mask
    const auto& num_casted_prefix = __popc(warp_reduce_or<kNumRecvUnrolls, true>(casted) & ((1u << (lane_id / kNumRecvUnrolls)) - 1));

    // The first lane of every `kNumRecvUnrolls`-lane group that owns metadata writes the entry
    if (lane_id < kNumLanes and lane_id % kNumRecvUnrolls == 0) {
        // Bit 0: cast flag; upper 31 bits: `num_casted_prefix`
        shared_cast_info[lane_id / kNumRecvUnrolls] = (num_casted_prefix << 1) | (casted ? 1u : 0u);
    }
}
/**
 * Adds `weight * value[e]` to `accum[e]` for the 16 elements owned by the calling lane.
 *
 * With `enable_cast` set, `ld_buffer` points at a 160-bit record: 16 magnitude codes of 9 bits
 * each, packed back to back, followed by 16 sign bits in the upper half of word 4. Code 0 decodes
 * to zero; any other code `c` decodes to `2 ^ ((c - 1) * step + log_amin)`.
 * Without `enable_cast`, `ld_buffer` points at `kNumRecvUnrolls` raw `int4` values holding BF16
 * pairs.
 *
 * @tparam kNumRecvUnrolls  number of `int4` values per lane in the raw layout; only 2 is supported.
 * @param ld_buffer    start of this lane's data.
 * @param accum        accumulator array with (at least) 16 floats.
 * @param log_amax     log2 upper bound of the element's group.
 * @param log_amin     log2 lower bound of the element's group.
 * @param enable_cast  selects the packed (true) or raw BF16 (false) layout.
 * @param weight       factor applied to every decoded value.
 */
template <int kNumRecvUnrolls>
__forceinline__ __device__ void decode_and_accumulate(
    const uint32_t* ld_buffer, float* accum, const float& log_amax, const float& log_amin,
    const bool& enable_cast, const float& weight) {
    EP_STATIC_ASSERT(kNumRecvUnrolls == 2, "kNumRecvUnrolls == 2 only");

    // Raw layout: every 32-bit word carries two BF16 values
    if (not enable_cast) {
        constexpr int kNumRawWords = kNumRecvUnrolls * sizeof(int4) / sizeof(uint32_t);
        #pragma unroll
        for (int w = 0; w < kNumRawWords; ++w) {
            const auto pair = *reinterpret_cast<const __hip_bfloat162*>(ld_buffer + w);
            accum[w * 2 + 0] += static_cast<float>(pair.x) * weight;
            accum[w * 2 + 1] += static_cast<float>(pair.y) * weight;
        }
        return;
    }

    // Packed layout
    constexpr int kNumBits = 10;
    constexpr int kNumValues = 1 << (kNumBits - 1);
    constexpr uint32_t kCodeMask = 0x1ff;
    const float step = (log_amax - log_amin) / static_cast<float>(kNumValues - 2);

    // `window[k]` is the bit stream shifted so that its bit 0 is stream bit `27 * k`;
    // windows 0..4 therefore expose three codes each and window 5 exposes the last one.
    uint32_t window[6];
    window[0] = ld_buffer[0];
    #pragma unroll
    for (int k = 1; k < 5; ++k) {
        window[k] = (ld_buffer[k - 1] >> (32 - k * 5)) | (ld_buffer[k] << (k * 5));
    }
    window[5] = ld_buffer[4] >> 7;

    // Bit `e` of `sign_bits` is the sign of element `e`
    const uint32_t sign_bits = ld_buffer[4] >> 16;

    #pragma unroll
    for (int e = 0; e < 16; ++e) {
        const uint32_t code = (window[e / 3] >> ((e % 3) * 9)) & kCodeMask;
        float magnitude = .0f;
        if (code != 0) {
            magnitude = __builtin_exp2f((code - 1) * step + log_amin);
        }
        const float value = ((sign_bits >> e) & 1) ? -magnitude : magnitude;
        accum[e] += value * weight;
    }
}
} // namespace internode_ll
} // namespace deep_ep
......@@ -72,8 +72,8 @@ __device__ __forceinline__ T shfl_xor(const T val, int laneMask, int width = kWa
return __shfl_xor(val, laneMask, width);
}
__device__ __forceinline__ int
shfl_sync(const int val, int srcLane = 0, int width = kWarpSize,
template <typename T>
__device__ __forceinline__ T shfl_sync(const T val, int srcLane = 0, int width = kWarpSize,
uint64_t shfl_sync_mask = kFullWarpMask) { // Let compiler deduce type
return __shfl(val, srcLane, width);
}
......@@ -115,6 +115,15 @@ template <> struct VecInt<16> {
using vec_t = native_int4;
};
// Adapter that gives a callable subscript syntax: `visitor[i]` evaluates `func(i)`.
// The callable is taken by rvalue and stored by value in `func`.
template <typename FuncT>
struct PatternVisitor {
    FuncT func;

    __device__ __host__ explicit PatternVisitor(FuncT&& callable)
        : func(std::forward<FuncT>(callable)) {}

    // Forwards the index to the stored callable and returns whatever it returns.
    __device__ __host__ auto operator[](const uint32_t& i) {
        return func(i);
    }
};
// Device-side helper that unconditionally calls `abort()`.
__device__ __forceinline__ void trap() {
abort();
}
......
......@@ -923,7 +923,8 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, zero_copy: bool = False, async_finish: bool = False,
handle: tuple, use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
......@@ -944,6 +945,7 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
......@@ -964,7 +966,7 @@ class Buffer:
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
zero_copy, async_finish, return_recv_hook, out)
use_logfmt, zero_copy, async_finish, return_recv_hook, out)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
......
......@@ -42,6 +42,7 @@ def test_main(num_tokens: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
......@@ -56,10 +57,12 @@ def test_main(num_tokens: int,
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1
# x_list = [x_rand]
for _ in range(4 if use_logfmt else 0):
# NOTES: make more LogFMT casts and also with some BF16
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random())
# NOTES: the last one is for performance testing
# Most of the values in the perf case is lower than the threshold, casting most channels
x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
......@@ -79,7 +82,7 @@ def test_main(num_tokens: int,
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for current_x in x_list:
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
......@@ -152,7 +155,7 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
for zero_copy in (False, True):
for zero_copy in (False, ) if use_logfmt else (False, True, ):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
......@@ -160,6 +163,7 @@ def test_main(num_tokens: int,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
......@@ -172,6 +176,10 @@ def test_main(num_tokens: int,
assert diff < (9e-4 if dispatch_use_quant else 1e-5), f'Error: diff={diff}, dispatch_use_quant={dispatch_use_quant}, zero_copy={zero_copy}'
hash_value ^= hash_tensor(combined_x)
if rank == 0:
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
......@@ -190,6 +198,7 @@ def test_main(num_tokens: int,
topk_idx,
topk_weights,
handle,
use_logfmt=use_logfmt,
return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
......@@ -251,6 +260,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=1)
do_pressure_test = args.pressure_test
......@@ -265,6 +275,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
......@@ -275,6 +286,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_ranks,
group,
buffer,
use_logfmt=args.use_logfmt,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment