Commit 35735902 authored by lijian6's avatar lijian6
Browse files

Merge branch 'int8-main' into 'main'

支持int8类型的kernel接口

See merge request dcutoolkit/deeplearing/DeepEP!3
parents baa261b5 6dfe3bc2
...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int ...@@ -1293,7 +1293,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
bool async, bool return_recv_hook) { bool async, bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode); EP_HOST_ASSERT(low_latency_mode);
...@@ -1316,13 +1316,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1316,13 +1316,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Buffer control
LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes);
auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Wait previous tasks to be finished // Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream // NOTES: the hook mode will always use the default stream
auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); auto compute_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
...@@ -1333,7 +1328,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1333,7 +1328,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Allocate packed tensors // Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16)); x.options().dtype(use_int8 ? torch::kInt8 : use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
...@@ -1345,13 +1340,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1345,13 +1340,18 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
// TODO: support unaligned cases // TODO: support unaligned cases
EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0); EP_HOST_ASSERT(hidden % (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4) == 0);
if (not use_ue8m0) { EP_HOST_ASSERT(!(use_ue8m0 && use_int8));
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA)); if (use_ue8m0) {
} else {
EP_HOST_ASSERT(round_scale); EP_HOST_ASSERT(round_scale);
packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank}, packed_recv_x_scales = torch::empty({num_local_experts, hidden / (FP8_QUANTIZATION_NUM_PER_CHANNEL * 4), num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kInt).device(torch::kCUDA)); torch::dtype(torch::kInt).device(torch::kCUDA));
} else if (use_int8) {
packed_recv_x_scales = torch::empty({num_local_experts, 1, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({num_local_experts, hidden / FP8_QUANTIZATION_NUM_PER_CHANNEL, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
} }
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
...@@ -1369,8 +1369,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i ...@@ -1369,8 +1369,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
x.data_ptr(), topk_idx.data_ptr<int64_t>(), x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second, next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank, num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks, num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0, use_fp8, round_scale, use_ue8m0, use_int8,
workspace, num_device_sms, launch_stream, phases); workspace, num_device_sms, launch_stream, phases);
}; };
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
...@@ -1427,12 +1427,6 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id ...@@ -1427,12 +1427,6 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx]; auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
// Buffer control
LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes);
auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1];
// Wait previous tasks to be finished // Wait previous tasks to be finished
// NOTES: the hook mode will always use the default stream // NOTES: the hook mode will always use the default stream
......
...@@ -177,7 +177,7 @@ public: ...@@ -177,7 +177,7 @@ public:
std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank, int num_experts, int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
bool async, bool return_recv_hook); bool async, bool return_recv_hook);
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>> std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
......
...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -147,7 +147,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
void* workspace, int num_device_sms, hipStream_t stream, int phases); void* workspace, int num_device_sms, hipStream_t stream, int phases);
void combine(void* combined_x, void combine(void* combined_x,
......
...@@ -31,8 +31,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) { ...@@ -31,8 +31,7 @@ __device__ void grid_barrier(int* global_counter, int num_blocks) {
__syncthreads(); __syncthreads();
__threadfence(); __threadfence();
if (threadIdx.x == 0 ) { if (threadIdx.x == 0 ) {
// ret = __hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); ret = __hip_atomic_fetch_add(&global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
ret = atomicAdd(&global_counter[0], 1);
} }
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
...@@ -84,7 +83,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, ...@@ -84,7 +83,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1); clean_0, num_clean_int_0, clean_1, num_clean_int_1);
} }
template <bool kUseFP8, bool kUseUE8M0, int kHidden> template <bool kUseFP8, bool kUseUE8M0, bool kUseInt8, int kHidden>
__global__ __launch_bounds__(16 * kWarpSize, 1) void __global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales, dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_src_info, int64_t* packed_recv_layout_range,
...@@ -115,14 +114,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -115,14 +114,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// FP8 staffs // FP8 staffs
constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL; constexpr int kNumPerChannels = FP8_QUANTIZATION_NUM_PER_CHANNEL;
const int num_scales = kHidden / kNumPerChannels; constexpr int kNumScales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16)); const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4); const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source // Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use // NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type; using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16))); const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + kNumScales * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
...@@ -147,7 +146,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -147,7 +146,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = (num_warps - 1) * kWarpSize; const auto num_threads = (num_warps - 1) * kWarpSize;
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4; const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
...@@ -159,13 +158,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -159,13 +158,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
__shared__ float int8_amaxf[kNumScales];
if constexpr(kUseInt8) {
if (thread_id < kNumScales) {
int8_amaxf[thread_id] = kFP8Margin;
}
__syncthreads();
}
// FP8 cast // FP8 cast
#pragma unroll #pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read // Read
auto int4_value = __ldg(x_int4 + i); auto int4_value = __ldg(x_int4 + i);
if (kUseFP8) { if constexpr(kUseFP8) {
// Calculate local amax // Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value); auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead]; float fp32_values[kNumElemsPerRead];
...@@ -178,25 +185,74 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -178,25 +185,74 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
// Reduce amax and scale // Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization"); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<16>(amax); amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale); const int scale_offset = i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL;
if (lane_id % 16 == 0)
rdma_x_scales[i * kNumElemsPerRead / FP8_QUANTIZATION_NUM_PER_CHANNEL] = scale_inv; if constexpr(kUseInt8) {
// 记录每128个数的最大值
// Cast into send buffer int8_amaxf[scale_offset] = fmaxf(amax, int8_amaxf[scale_offset]);
vec_t int2_value; } else {
auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value); calculate_fp8_scales(amax, scale, scale_inv, round_scale);
#pragma unroll if (lane_id % 16 == 0)
for (int j = 0; j < kNumElemsPerRead; j += 2) { rdma_x_scales[scale_offset] = scale_inv;
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ); // Cast into send buffer
vec_t int2_value;
auto fp8x2_values = reinterpret_cast<__hip_fp8x2_storage_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; j += 2) {
float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
}
rdma_x_vec[i] = int2_value;
} }
rdma_x_vec[i] = int2_value;
} else { } else {
// Reinterpret-cast is for C++14 compatibility // Reinterpret-cast is for C++14 compatibility
rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value); rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
} }
} }
__syncthreads(); __syncthreads();
if constexpr(kUseInt8) {
float amax_per_token = kFP8Margin;
// 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id;
float tmp_amaxf = 0;
if(src_idx < kNumScales) {
tmp_amaxf = int8_amaxf[src_idx];
}
tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
int8_amaxf[0] = fmaxf(tmp_amaxf, int8_amaxf[0]);
__syncthreads();
}
amax_per_token = int8_amaxf[0];
// 根据最大值计算scale
float scale, scale_inv;
calculate_int8_scales(amax_per_token, scale, scale_inv);
if (thread_id == 0) {
rdma_x_scales[0] = scale_inv;
}
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
// Cast into send buffer
vec_t int2_value;
auto int8_values = reinterpret_cast<int8_t*>(&int2_value);
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
auto fp32_value = static_cast<float>(bf16_values[j]);
auto fp32_value_scaled = fp32_value * scale;
int8_values[j] = static_cast<int8_t>(nearbyintf(fp32_value_scaled));
}
rdma_x_vec[i] = int2_value;
}
__syncthreads();
}
// Issue IBGDA sends // Issue IBGDA sends
if (dst_expert_idx >= 0) { if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
...@@ -339,9 +395,10 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -339,9 +395,10 @@ LOW_LATENCY_DISPATCH_RECV:
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(num_scales, sizeof(float) / sizeof(scale_t)); const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
(kUseInt8 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups // Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups];
...@@ -362,8 +419,7 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -362,8 +419,7 @@ LOW_LATENCY_DISPATCH_RECV:
// no needs to reset because there is no iteration // no needs to reset because there is no iteration
if (lane_id == 0){ if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
} }
syncwarp(); syncwarp();
...@@ -372,7 +428,7 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -372,7 +428,7 @@ LOW_LATENCY_DISPATCH_RECV:
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
// Copy tokens // Copy tokens
EP_DEVICE_ASSERT(num_scales <= 64); EP_DEVICE_ASSERT(kNumScales <= 64);
for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) { for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
// Copy source info // Copy source info
const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg); const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
...@@ -387,24 +443,30 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -387,24 +443,30 @@ LOW_LATENCY_DISPATCH_RECV:
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales // Copy scales
if (kUseFP8) { if constexpr(kUseFP8) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes); const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t)); const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i; const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack; const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if (lane_id < num_scales) { if constexpr(kUseInt8) {
const auto pack_idx = lane_id / num_elems_per_pack; if (lane_id == 0) {
const auto elem_idx = lane_id % num_elems_per_pack; recv_x_scales[token_idx] = ld_nc_global(src_scales);
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id)); }
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; } else {
} if (lane_id < kNumScales) {
if (lane_id + kWarpSize < num_scales) { const auto pack_idx = lane_id / num_elems_per_pack;
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack; const auto elem_idx = lane_id % num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack; auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize)); recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; }
if (lane_id + kWarpSize < kNumScales) {
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
} }
} }
} }
...@@ -420,7 +482,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -420,7 +482,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int64_t* next_clean, int num_next_clean_int, int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks, int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_fp8, bool round_scale, bool use_ue8m0, bool use_int8,
void* workspace, int num_device_sms, void* workspace, int num_device_sms,
hipStream_t stream, int phases) { hipStream_t stream, int phases) {
constexpr int kNumMaxTopK = 11; constexpr int kNumMaxTopK = 11;
...@@ -439,11 +501,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, ...@@ -439,11 +501,13 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
#define DISPATCH_LAUNCH_CASE(hidden) { \ #define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, hidden>; \ auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \ if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \ dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0) \ if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \ dispatch_func = dispatch<true, true, false, hidden>; \
if (use_int8) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \ packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \ packed_recv_src_info, packed_recv_layout_range, \
...@@ -575,8 +639,7 @@ combine(void* combined_x, ...@@ -575,8 +639,7 @@ combine(void* combined_x,
// Put finishing flag // Put finishing flag
EP_DEVICE_ASSERT(num_warps_per_group > 1); EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (lane_id == 0){ if (lane_id == 0){
// volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1,__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
volatile int ret = atomicAdd((int*)&sync_large_warp_counters[warp_group_id], 1);
} }
syncwarp(); syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group); while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
......
...@@ -184,9 +184,8 @@ __device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) { ...@@ -184,9 +184,8 @@ __device__ __forceinline__ int64_t ld_acquire_global(const int64_t *ptr) {
__device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) { __device__ __forceinline__ int atomic_add_release_global(const int *ptr, int value) {
int ret; int ret;
// ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE, ret = __hip_atomic_fetch_add(const_cast<int *>(ptr), value, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
// __HIP_MEMORY_SCOPE_AGENT); // ret = atomicAdd((int*)ptr, value);
ret = atomicAdd((int*)ptr, value);
return ret; return ret;
} }
...@@ -342,15 +341,10 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) { ...@@ -342,15 +341,10 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
return *reinterpret_cast<dtype_t *>(recv_int_values); return *reinterpret_cast<dtype_t *>(recv_int_values);
} }
#ifndef FORCE_NVSHMEM_API
constexpr float kFP8Margin = 1e-4; constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 240.0f; constexpr float kFinfoAmaxE4M3 = 240.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3; constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
#else constexpr float kInt8Amax = 127.0f;
constexpr float kFP8Margin = 1e-4;
constexpr float kFinfoAmaxE4M3 = 448.0f;
constexpr float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
#endif
__forceinline__ __device__ float fast_pow2(int x) { __forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127` // We can ensure `-126 <= x and x <= 127`
...@@ -376,6 +370,11 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, f ...@@ -376,6 +370,11 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, f
} }
} }
__forceinline__ __device__ void calculate_int8_scales(float amax, float& scale, float& scale_inv) {
scale = kInt8Amax / amax;
scale_inv = amax / kInt8Amax;
}
template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>> template <bool kIsUE8M0, typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
__forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) { __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value) {
if constexpr (kIsUE8M0) { if constexpr (kIsUE8M0) {
......
...@@ -804,7 +804,7 @@ class Buffer: ...@@ -804,7 +804,7 @@ class Buffer:
# noinspection PyTypeChecker # noinspection PyTypeChecker
def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int, num_experts: int, num_max_dispatch_tokens_per_rank: int, num_experts: int,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, use_int8: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \ async_finish: bool = False, return_recv_hook: bool = False) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
""" """
...@@ -824,6 +824,7 @@ class Buffer: ...@@ -824,6 +824,7 @@ class Buffer:
use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
round_scale: whether round the scaling factors into power of 2. round_scale: whether round the scaling factors into power of 2.
use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
use_int8: whether to enable INT8 casting.
async_finish: the current stream will not wait for the communication kernels to be finished if set. async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
...@@ -851,7 +852,7 @@ class Buffer: ...@@ -851,7 +852,7 @@ class Buffer:
packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
self.runtime.low_latency_dispatch(x, topk_idx, self.runtime.low_latency_dispatch(x, topk_idx,
num_max_dispatch_tokens_per_rank, num_experts, num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0, use_fp8, round_scale, use_ue8m0, use_int8,
async_finish, return_recv_hook) async_finish, return_recv_hook)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx, tensors_to_record = (x, topk_idx,
......
import argparse
import random
import os
import torch
import torch.distributed as dist
from functools import partial
from typing import Literal, Set
import deep_ep
from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, per_token_cast_back_int8
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
num_topk: int,
rank: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset = 128
assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)'
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset)
x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1)
x_list = [x]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs()
# Randomly mask some positions
for _ in range(10):
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1
all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda')
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
# For failure simulation and shrink testing
mask_status = torch.zeros((num_ranks,), dtype=torch.int, device='cuda')
# Check dispatch correctness
do_check = True
hash_value, num_times = 0, 0
for current_x in x_list:
for return_recv_hook in (False, ):
for dispatch_use_fp8 in (True, ):
for round_scale in (False, ):
for use_ue8m0 in (False, ):
num_times += 1
use_int8 = True
for _ in range(1):
packed_recv_x, packed_recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, use_int8=use_int8,
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
hook() if return_recv_hook else event.current_stream_wait()
packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
for i in range(num_local_experts if do_check else 0):
expert_id = rank * num_local_experts + i
recv_x = per_token_cast_back_int8(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i]
recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i]
# Check expert indices
int_mask = (2 ** 32) - 1
num_valid_tokens = recv_count.item()
assert num_valid_tokens == (
recv_layout_range
& int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item(
), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item()}'
if num_valid_tokens == 0:
continue
# Check received data
if current_x is x:
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if round_scale:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007
elif use_int8:
assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.01
else:
assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_fp8:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
else:
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
print("dispatch int 8 pass")
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
mat_0 @ mat_1
hook()
# noinspection PyShadowingNames
def test_func(return_recv_hook: bool):
recv_x, recv_count, handle, event, hook = \
buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts,
use_fp8=True, round_scale=False, use_ue8m0=False, use_int8=True,
async_finish=False, return_recv_hook=return_recv_hook)
large_gemm_with_hook(hook) if return_recv_hook else None
# Calculate bandwidth
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
for i in range(num_tokens):
num_selections = (topk_idx[i] != -1).sum().item()
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
num_combine_comm_bytes += num_bf16_bytes * num_selections
# Separate profiling
for return_recv_hook in (True, ):
group.barrier()
dispatch_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
kernel_names='dispatch',
barrier_comm_profiling=True,
suppress_kineto_output=True,
num_kernels_per_period=2 if return_recv_hook else 1)
if not return_recv_hook:
print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us',
flush=True)
else:
print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us',
flush=True)
return hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1)
# do_pressure_test = args.pressure_test
# for seed in range(int(1e9) if do_pressure_test else 0):
# if local_rank == 0:
# print(f'Testing with seed {seed} ...', flush=True)
# ref_hash = test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed)
# for _ in range(20):
# assert test_main(num_tokens,
# hidden,
# num_experts,
# num_topk,
# rank,
# num_ranks,
# group,
# buffer,
# seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__':
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser = argparse.ArgumentParser(description='Test low-latency EP kernels')
parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)')
parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)')
parser.add_argument('--hidden', type=int, default=2560, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 288)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication')
parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing')
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
args = parser.parse_args()
num_processes = args.num_processes
torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes)
...@@ -73,6 +73,29 @@ def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): ...@@ -73,6 +73,29 @@ def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:,:n].contiguous() return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:,:n].contiguous()
def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
if x_int8.numel() == 0:
return x_int8.to(torch.bfloat16)
assert x_int8.dim() == 2
m, n = x_int8.shape
aligned_n = align_up(n, 128)
x_int8_padded = torch.nn.functional.pad(
x_int8, (0, aligned_n - n), mode='constant', value=0
)
x_fp32_padded = x_int8_padded.to(torch.float32).view(m, -1, 1)
x_scales = x_scales.view(m, -1, 1).to(torch.float32)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq = (x_fp32_padded * x_scales).view(m, aligned_n)
return x_deq[:, :n].to(torch.bfloat16).contiguous()
def inplace_unique(x: torch.Tensor, num_slots: int): def inplace_unique(x: torch.Tensor, num_slots: int):
assert x.dim() == 2 assert x.dim() == 2
mask = x < 0 mask = x < 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment