Commit da6da7c3 authored by lishen's avatar lishen
Browse files

low-latency添加dispatch分层优化和combine gemm overlap

parent ea76f44e
......@@ -135,9 +135,11 @@ struct LowLatencyLayout {
}
LowLatencyLayout(void *rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) {
int num_ranks, int num_experts, bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
const int num_scales = hidden / QUANTIZATION_GROUPSIZE;
const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // 计算结点数
// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers
......@@ -152,7 +154,9 @@ struct LowLatencyLayout {
(quant_group_size == 0 ? 4 : num_scales) * sizeof(float)); // 应该是1,但是代码中为了满足int4对齐
// 与internode_ll::combine 中的 num_bytes_per_slot 相等
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) + num_scales * sizeof(__hip_bfloat162);
size_t num_bytes_per_combine_msg = hidden * sizeof(hip_bfloat16) +
(enable_dispatch_ll_layered ? 0 : // 即enable_combine_overlap==true,执行函数combine_sbo
num_scales * sizeof(__hip_bfloat162));
// Send buffer
size_t dispatch_send_buffer_bytes =
......@@ -176,6 +180,10 @@ struct LowLatencyLayout {
// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int64_t);
if (enable_dispatch_ll_layered) {
dispatch_recv_count_buffer_bytes +=
NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
}
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = ALIGN<size_t>(signaling_buffer_bytes, 128);
......@@ -205,9 +213,11 @@ struct LowLatencyLayout {
};
inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden,
int num_ranks, int num_experts, int quant_group_size=0) {
int num_ranks, int num_experts,
bool enable_dispatch_ll_layered=false, int quant_group_size=0) {
auto num_bytes =
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size)
LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size)
.total_bytes;
return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) *
NUM_BUFFER_ALIGNMENT_BYTES;
......
......@@ -13,11 +13,14 @@
namespace deep_ep {
Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink)
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap)
: rank(rank), num_ranks(num_ranks), num_nvl_bytes(num_nvl_bytes),
num_rdma_bytes(num_rdma_bytes), low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy),
enable_shrink(enable_shrink),
enable_dispatch_ll_layered(enable_dispatch_ll_layered),
enable_combine_overlap(enable_combine_overlap),
comm_stream(at::hip::getStreamFromPoolMasqueradingAsCUDA(true)) {
// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
......@@ -25,6 +28,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int *);
EP_HOST_ASSERT(enable_shrink == false);
if (enable_dispatch_ll_layered)
EP_HOST_ASSERT(enable_combine_overlap == true);
// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and
......@@ -1274,7 +1279,8 @@ Buffer::internode_combine(
void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts, int quant_group_size) {
EP_HOST_ASSERT(low_latency_mode);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size);
auto clean_meta_0 = layout.buffers[0].clean_meta();
auto clean_meta_1 = layout.buffers[1].clean_meta();
......@@ -1311,7 +1317,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto num_local_experts = num_experts / num_ranks;
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size);
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered, quant_group_size);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
......@@ -1336,7 +1342,16 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
}
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(packed_recv_x_dtype));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
torch::Dtype dtype = torch::kInt32;
if(enable_dispatch_ll_layered or enable_combine_overlap){
dtype = torch::kInt64;
}
auto packed_recv_src_info = torch::empty(
{num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
torch::dtype(dtype).device(torch::kCUDA)
);
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
......@@ -1371,10 +1386,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}
if(!enable_dispatch_ll_layered){
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
internode_ll::dispatch(
packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(),
......@@ -1406,17 +1423,72 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
} else {
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch_ll_layered(
!enable_dispatch_ll_layered,
packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int64_t>(), packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
global_atomic_counter.data_ptr<int>(),
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
quant_type, quant_group_size, fp8_round_scale,
workspace, num_device_sms, launch_stream, phases);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
}
}
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out) {
EP_HOST_ASSERT(low_latency_mode);
// combine overlap checks
EP_HOST_ASSERT((!enable_combine_overlap || return_recv_hook) and "Overlap mode requires return_recv_hook=True"); // 启用 overlap 时, 必须 hook = True
EP_HOST_ASSERT((!enable_combine_overlap || packed_recv_count.has_value()) && "Overlap mode requires packed_recv_count has value");
EP_HOST_ASSERT((!enable_combine_overlap || comp_signal.has_value()) && "Overlap mode requires comp_signal has value");
EP_HOST_ASSERT((!enable_combine_overlap || block_m != -1) && "Overlap mode requires block_m != -1");
EP_HOST_ASSERT((!enable_combine_overlap || threshold != -1) && "Overlap mode requires threshold != -1");
EP_HOST_ASSERT((!enable_combine_overlap || num_sms != -1) && "Overlap mode requires num_sms != -1");
if (comp_signal.has_value()) {
EP_HOST_ASSERT(comp_signal->dim() == 1 and comp_signal->is_contiguous());
EP_HOST_ASSERT(comp_signal->scalar_type() == torch::kInt32);
EP_HOST_ASSERT(comp_signal->size(0) == num_experts / num_ranks * ((num_ranks * num_max_dispatch_tokens_per_rank + 63) / 64));
}
// Tensor checks
EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16);
......@@ -1430,7 +1502,12 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank);
EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32);
EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous());
if (!enable_dispatch_ll_layered && !enable_combine_overlap) {
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0));
} else {
EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt64 and x.size(0) == src_info.size(0));
}
EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
......@@ -1446,7 +1523,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
// Buffer control
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered);
EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1];
......@@ -1472,8 +1549,10 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
if(!enable_combine_overlap) {
auto launcher = [=](int phases) {
internode_ll::combine(combined_x.data_ptr(),
internode_ll::combine(
combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
......@@ -1506,10 +1585,55 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
// Return values
return {combined_x, event, recv_hook};
} else {
auto launcher = [=](int phases) {
internode_ll::combine_sbo(
combined_x.data_ptr(),
buffer.combine_rdma_recv_data_buffer,
buffer.combine_rdma_recv_flag_buffer,
buffer.combine_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
src_info.data_ptr<int64_t>(), layout_range.data_ptr<int64_t>(),
/* ll_layered 新增参数 */
!enable_dispatch_ll_layered,
/* overlap 新增参数 */
packed_recv_count.has_value() ? packed_recv_count->data_ptr<int>() : nullptr,
comp_signal.has_value() ? comp_signal->data_ptr<int>() : nullptr,
block_m, threshold, num_sms,
/* 辅助tensor */
global_atomic_counter.data_ptr<int>(),
combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
next_clean_meta.first, next_clean_meta.second,
num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
workspace, num_device_sms, launch_stream,
phases, zero_copy);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
// Wait streams
std::optional<EventHandle> event;
if (async) {
// NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens,
// so in Python API, we must wrap all tensors into the event handle.
event = EventHandle(launch_stream);
} else if (not return_recv_hook) {
stream_wait(compute_stream, launch_stream);
}
// Receiver callback
std::optional<std::function<void()>> recv_hook = std::nullopt;
if (return_recv_hook)
recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
// Return values
return {combined_x, event, recv_hook};
}
}
torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, enable_dispatch_ll_layered);
auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
......@@ -1540,7 +1664,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);
pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool>())
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool, bool>())
.def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
......
......@@ -35,6 +35,8 @@ private:
// Shrink mode buffer
bool enable_shrink = false;
bool enable_dispatch_ll_layered = false;
bool enable_combine_overlap = false;
int* mask_buffer_ptr = nullptr;
int* sync_buffer_ptr = nullptr;
......@@ -77,7 +79,8 @@ private:
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes,
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink);
bool low_latency_mode, bool explicitly_destroy, bool enable_shrink,
bool enable_dispatch_ll_layered, bool enable_combine_overlap);
~Buffer() noexcept(false);
......@@ -183,6 +186,9 @@ public:
std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& packed_recv_count,
const std::optional<torch::Tensor>& comp_signal,
int block_m, int threshold, int num_sms,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt,
......
......@@ -150,6 +150,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int quant_type, int group_size, bool fp8_round_scale,
void* workspace, int num_device_sms, hipStream_t stream, int phases);
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases);
void combine(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
......@@ -163,6 +177,24 @@ void combine(void* combined_x,
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap 新增控制参数
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// 同步与统计参数
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// 维度与配置参数
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// 系统资源与执行参数
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy);
} // namespace internode_ll
} // namespace deep_ep
......@@ -1086,6 +1086,1041 @@ void combine(void* combined_x,
#undef COMBINE_LAUNCH_CASE
}
template <int kHidden, int kQuantType=0, int kQuantGroupSize=0, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
dispatch_ll_layered(
bool disable_ll_layered,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool fp8_round_scale, int phases) {
// 定义量化类型的枚举
enum class QuantType {
None = 0, // 不进行量化
Int8 = 1, // 采用 Int8 量化
FP8_E4M3 = 2, // 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0 = 3, // 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2 = 4 // 采用 FP8 量化 __HIP_E5M2
};
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_sms = static_cast<int>(gridDim.x);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_local_experts = num_experts / num_ranks;
const auto warp_group_id = warp_id / num_warps_per_group;
const auto sub_warp_id = warp_id % num_warps_per_group;
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
char* rdma_recv_x_cahr_ptr = reinterpret_cast<char*>(rdma_recv_x);
const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
const auto num_nodes = num_ranks / num_nvl_ranks;
int* data_ready_counter = reinterpret_cast<int*>(rdma_recv_count + num_experts);
int* data_ready_send_buffer =
data_ready_counter + num_nodes * num_max_dispatch_tokens_per_rank * num_nvl_ranks;
int* next_clean_data_ready_counter = reinterpret_cast<int*>(next_clean + num_experts);
if (!disable_ll_layered) {
if (thread_id < num_nvl_ranks) {
__hip_atomic_store(data_ready_send_buffer + thread_id, 2, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
}
__syncthreads();
// May extract UE8M0 from the scales
constexpr bool kUseQuant8Bit = kQuantType > 0;
constexpr bool kUseUE8M0 = kQuantType == 3; // QuantType::FP8_UE8M0
using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
// FP8 staffs
constexpr int kNumPerChannels = QUANTIZATION_GROUPSIZE;
constexpr int kNumScales = kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseQuant8Bit ? sizeof(__hip_fp8_storage_t) : sizeof(hip_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);
// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = typename std::conditional<kUseQuant8Bit, int2, int4>::type;
const size_t num_bytes_per_meta = sizeof(int4);
const size_t num_bytes_per_data = (kUseQuant8Bit ? (kHidden + (kQuantGroupSize == 0 ? 4 : kNumScales) * sizeof(float)) : (kHidden * sizeof(hip_bfloat16)));
const size_t num_bytes_per_msg = num_bytes_per_meta + num_bytes_per_data;
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
char* rdma_recv_x_meta = rdma_recv_x_cahr_ptr;
char* rdma_recv_x_data = rdma_recv_x_cahr_ptr + num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_meta;
// Expert counts
__shared__ int shared_num_tokens_sent_per_expert[kMaxNumWarps];
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
// There are 2 kinds of warps in this part:
// 1. The first-kind warps for FP8 cast and sending top-k tokens
// 2. The last warp for reading `topk_idx` and count for per-expert information
if (warp_id < num_warps) {
constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(hip_bfloat16);
constexpr int kNumThreadPerGroup = QUANTIZATION_GROUPSIZE / kNumElemsPerRead;
// EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
const auto num_threads = num_warps * kWarpSize;
constexpr int hidden_bf16_int4 = kHidden / kNumElemsPerRead;
for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
const auto rdma_x_src_idx = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
// Overlap top-k index read and source token index write
auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
// 用于记录per-channel量化的amax
__shared__ float channel_amaxf[kNumScales];
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
if (thread_id < kNumScales) {
channel_amaxf[thread_id] = 0.0;
}
__syncthreads();
}
// FP8 cast
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);
if constexpr(kUseQuant8Bit) {
// Calculate local amax
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = 0.0, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize / kNumPerChannels == 4, "Invalid vectorization");
amax = warp_reduce_max<kNumThreadPerGroup>(amax);
const int scale_offset = i * kNumElemsPerRead / QUANTIZATION_GROUPSIZE;
if constexpr(kQuantGroupSize == 0) {
// 记录每128个数的最大值
channel_amaxf[scale_offset] = fmaxf(amax, channel_amaxf[scale_offset]);
} else {
calculate_quant8bit_scales<kQuantType>(amax, scale, scale_inv, fp8_round_scale);
if (lane_id % kNumThreadPerGroup == 0)
rdma_x_scales[scale_offset] = scale_inv;
// Cast into send buffer
vec_t int2_value;
pack_quantized_values<kQuantType, kNumElemsPerRead>(fp32_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
} else {
// Reinterpret-cast is for C++14 compatibility
rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
}
}
__syncthreads();
if constexpr(kUseQuant8Bit && kQuantGroupSize == 0) {
float amax_per_token = 0.0;
// 并行规约,计算每个token的amax
for (int s = 0; s < kNumScales; s+=kWarpSize) {
int src_idx = s + lane_id;
float tmp_amaxf = 0;
if(src_idx < kNumScales) {
tmp_amaxf = channel_amaxf[src_idx];
}
tmp_amaxf = warp_reduce_max<kWarpSize>(tmp_amaxf);
channel_amaxf[0] = fmaxf(tmp_amaxf, channel_amaxf[0]);
__syncthreads();
}
amax_per_token = channel_amaxf[0];
// 根据最大值计算scale
float scale, scale_inv;
calculate_quant8bit_scales<kQuantType>(amax_per_token, scale, scale_inv, fp8_round_scale);
if (thread_id == 0) {
rdma_x_scales[0] = scale_inv;
}
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);
auto bf16_values = reinterpret_cast<hip_bfloat16*>(&int4_value);
// Cast into send buffer
vec_t int2_value;
pack_quantized_values<kQuantType, kNumElemsPerRead>(bf16_values, scale, int2_value);
rdma_x_vec[i] = int2_value;
}
__syncthreads();
}
// Issue IBGDA sends
if (dst_expert_idx >= 0) {
int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
slot_idx = shfl_sync(slot_idx, 0);
const auto dst_rank = dst_expert_idx / num_local_experts;
const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
if(!disable_ll_layered){
int send_node_id = dst_expert_idx / num_local_experts / num_nvl_ranks;
auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks +
rank % num_nvl_ranks; // send data to same gpu_device_id_rank(same-rail rdma traffic)
auto real_dst_expert_id = real_write_dst_rank * num_local_experts + dst_expert_local_idx;
auto tmp_dst_expert_id = lane_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id)) : -1;
auto tmp_dst_node_id = tmp_dst_expert_id >= 0 ? tmp_dst_expert_id / num_local_experts / num_nvl_ranks : -1;
for (int i = 0; i < warp_id; ++i) {
auto dst_node_id = shfl_sync(tmp_dst_node_id, i); // broadcast
if (dst_node_id == send_node_id) { // whether to send repeatedly
send_node_id = -1;
break;
}
}
if (send_node_id != -1) {
// ======================================= token data ==========================================
int* src_data_ptr = rdma_x_src_idx + 4;
char* dst_data_ptr = rdma_recv_x_data +
(rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data +
token_idx * num_bytes_per_data;
const auto p2p_data_ptr = internode::shmem_get_p2p_ptr((void*)(dst_data_ptr), rank, real_write_dst_rank);
if (!p2p_data_ptr) {
internode_ll_putmem_nbi(
reinterpret_cast<void*>(dst_data_ptr), reinterpret_cast<void*>(src_data_ptr),
num_ranks, real_write_dst_rank, dst_expert_local_idx, num_bytes_per_data);
} else {
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_data_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_data_ptr);
UNROLLED_WARP_COPY_LL(8, lane_id, num_bytes_per_data / sizeof(int4), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// ======================================== token data flag =======================================
uint64_t src_data_flag_ptr = reinterpret_cast<uint64_t>(data_ready_send_buffer);
const auto data_ready_counter_ptr = reinterpret_cast<uint64_t>(data_ready_counter) +
(rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks * sizeof(int) +
token_idx * num_nvl_ranks * sizeof(int);
uint64_t data_ready_counter_p2p_ptr = internode::shmem_get_p2p_ptr((void*)(data_ready_counter_ptr), rank, real_write_dst_rank);
if (data_ready_counter_p2p_ptr == 0) {
// internode::shmemx_int8_put_nbi_warp_refactoring(
// reinterpret_cast<signed char*>(data_ready_counter_ptr), reinterpret_cast<signed char*>(src_data_flag_ptr),
// num_nvl_ranks * sizeof(int), num_ranks + dst_expert_local_idx * num_ranks + real_write_dst_rank, rank, real_write_dst_rank, true);
internode_ll_putmem_nbi(
reinterpret_cast<void*>(data_ready_counter_ptr), reinterpret_cast<void*>(src_data_flag_ptr),
num_ranks, real_write_dst_rank, dst_expert_local_idx, num_nvl_ranks * sizeof(int));
} else {
int* dst_int_ptr = reinterpret_cast<int*>(data_ready_counter_p2p_ptr);
if(lane_id < num_nvl_ranks){
__hip_atomic_store(dst_int_ptr + lane_id, 2, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM);
}
}
}
// ========================= meta data=============================
const auto src_meta_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
const auto dst_meta_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_meta) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
slot_idx * num_bytes_per_meta;
uint64_t p2p_meta_ptr = internode::shmem_get_p2p_ptr((void*)(dst_meta_ptr), rank, dst_rank);
if (!p2p_meta_ptr) {
// internode::shmemx_int8_put_nbi_warp_refactoring(
// reinterpret_cast<signed char*>(dst_meta_ptr), reinterpret_cast<signed char*>(src_meta_ptr),
// num_bytes_per_meta, num_ranks + dst_expert_local_idx * num_ranks + dst_rank, rank, dst_rank, true);
internode_ll_putmem_nbi(
reinterpret_cast<void*>(dst_meta_ptr), reinterpret_cast<void*>(src_meta_ptr),
num_ranks, dst_rank, dst_expert_local_idx, num_bytes_per_meta);
} else {
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_meta_ptr);
int4* dst_int4_ptr = reinterpret_cast<int4*>(p2p_meta_ptr);
if(lane_id==0){
dst_int4_ptr[0] = src_int4_ptr[0];
}
}
syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + real_dst_expert_id, 1) : 0;
} else {
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_putmem_nbi((void*)dst_ptr, (void*)src_ptr,
num_ranks, dst_rank, dst_expert_local_idx,
num_bytes_per_msg);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
// Increase counter after finishing
syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
}
}
}
if (warp_id == num_warps - 1) {
// EP_DEVICE_ASSERT(num_sms > 1);
if (sm_id == 0) {
if (disable_ll_layered) {
// The first SM is also responsible for checking QPs
// The first SM is also responsible for cleaning the next buffer
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
next_clean[i] = 0;
// Notify before executing `int_p`
syncwarp();
#pragma unroll
for (int i = lane_id; i < num_experts; i += kWarpSize)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
}
// This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kMaxNumWarps] = {0};
int waiting_flag[kMaxNumWarps] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups;
const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);
// Per lane count
#pragma unroll 8
for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx] ++;
if (!disable_ll_layered) {
if (idx < 0)
continue;
const auto dst_rank = idx / num_local_experts;
const auto dst_expert_local_idx = idx % num_local_experts;
auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks + rank % num_nvl_ranks;
auto real_dst_expert_id = real_write_dst_rank * num_local_experts + dst_expert_local_idx;
if (real_dst_expert_id >= expert_begin_idx and real_dst_expert_id < expert_end_idx)
waiting_flag[real_dst_expert_id - expert_begin_idx] ++;
}
}
// Warp reduce
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
auto waiting_flag_sum = 0;
if (!disable_ll_layered) { // only open ll dispatch opt, should do
waiting_flag_sum = warp_reduce_sum(waiting_flag[i - expert_begin_idx]);
}
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - waiting_flag_sum - sum);
}
}
}
if (!disable_ll_layered and sm_id == num_sms - 1) {
// The first SM is also responsible for cleaning the next buffer
for (int i = thread_id; i < num_experts; i += blockDim.x) // clean for combine
next_clean[i] = 0;
// clean data ready flag
for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += blockDim.x) {
int token_idx = i / num_ranks;
int rank_id = i % num_ranks;
auto node_id = rank_id / num_nvl_ranks;
auto nvl_rank_id = rank_id % num_nvl_ranks;
auto* data_ready_flag_ptr = reinterpret_cast<int*>(next_clean_data_ready_counter) +
node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks + token_idx * num_nvl_ranks + rank % num_nvl_ranks;
EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter <
num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks * sizeof(int));
const auto data_ready_p2p_src_ptr =
internode::shmem_get_p2p_ptr((void*)(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
}
__syncthreads();
#pragma unroll
for (int i = thread_id; i < num_experts; i += blockDim.x)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}
__syncthreads();
// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
auto dst_ptr = rdma_recv_count + dst_expert_local_idx * num_ranks + rank;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_long_atomic_add(dst_ptr, -num_tokens_sent - 1,
num_ranks, dst_rank, dst_expert_local_idx);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release(reinterpret_cast<int *>(p2p_ptr), -num_tokens_sent - 1);
}
// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
// Clean `packed_recv_count`
if (dst_rank == 0)
packed_recv_count[dst_expert_local_idx] = 0;
}
syncwarp();
// Receiving phase
LOW_LATENCY_DISPATCH_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
if (phases & LOW_LATENCY_SEND_PHASE){
grid_barrier(global_atomic_counter, num_sms);
}
// 16 is the max possible number of warps in AMD GPUs
constexpr int num_sync_large_iteration = kMaxNumWarps ;
__shared__ volatile int sync_large_warp_counters[num_sync_large_iteration];
#pragma unroll
for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
sync_large_warp_counters[i] = 0;
}
__syncthreads();
// Receiving and packing
if (responsible_expert_idx < num_experts) {
const auto src_rank = responsible_expert_idx / num_local_experts;
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
uint8_t* rdma_recv_x_uint8 = nullptr;
if (!disable_ll_layered) {
rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x_meta) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta;
}
if (disable_ll_layered) {
rdma_recv_x_uint8 = reinterpret_cast<uint8_t*>(rdma_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
}
const auto recv_x_int4 = reinterpret_cast<int4*>(packed_recv_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
const auto num_aligned_scales = ALIGN<int>(kNumScales, sizeof(float) / sizeof(scale_t));
const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank *
(kQuantGroupSize == 0 ? 1 : num_aligned_scales);
// Shared between sub-warps in warp groups
__shared__ int shared_num_recv_tokens[kMaxNumWarps], shared_recv_token_begin_idx[kMaxNumWarps];
// Wait tokens to arrive
// NOTES: using sub-warp 1 to overlap with sub-warp 0
int num_recv_tokens, recv_token_begin_idx;
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 1 and lane_id == 0) {
while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
num_recv_tokens = -num_recv_tokens - 1;
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
}
// no needs to reset because there is no iteration
if (lane_id == 0){
volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
}
syncwarp();
while (sync_large_warp_counters[warp_group_id] < num_warps_per_group);
num_recv_tokens = shared_num_recv_tokens[warp_group_id];
recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
const auto real_read_src_rank = src_rank % num_nvl_ranks + rank / num_nvl_ranks * num_nvl_ranks;
// Copy tokens
EP_STATIC_ASSERT(kNumScales <= 64, "Invalid hidden size");
for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
int4* src_data = nullptr;
if (!disable_ll_layered) {
int* src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_meta);
int src_token_idx = __builtin_nontemporal_load(src_src_idx);
if (lane_id == 0) {
recv_src_info[recv_token_begin_idx + i] = pack2<int, int64_t>(src_token_idx, src_rank);
}
const auto data_ready_flag_src_ptr = data_ready_counter +
(src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks +
src_token_idx * num_nvl_ranks +
rank % num_nvl_ranks;
const auto src_data_ready_flag_p2p_ptr =
reinterpret_cast<int*>(internode::shmem_get_p2p_ptr((void*)(data_ready_flag_src_ptr), rank, real_read_src_rank));
if (lane_id == 0) {
int tmp = 0;
auto start_time = clock64();
bool flag_get = false;
while (tmp != 2) {
tmp = __hip_atomic_load(src_data_ready_flag_p2p_ptr, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_SYSTEM);
if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
printf(
"DeepEP ll dispatch recv data timeout, src_rank:%d, dst_rank: %d, real_read_src_rank:%d,src_token_idx:%d "
"dst RDMA lane: %d, num_recv_tokens: %d\n",
src_rank,
rank,
real_read_src_rank,
src_token_idx,
lane_id,
num_recv_tokens
);
break;
}
}
}
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_data) +
(src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data
+ src_token_idx * num_bytes_per_data;
uint64_t src_ptr_p2p = internode::shmem_get_p2p_ptr((void*)(src_ptr), rank, real_read_src_rank);
src_data = reinterpret_cast<int4*>(src_ptr_p2p);
}
if (disable_ll_layered) {
const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
int src_token_idx = __builtin_nontemporal_load(src_src_idx);
if (lane_id == 0)
// 加入 源rank 信息
recv_src_info[recv_token_begin_idx + i] = pack2<int, int64_t>(src_token_idx, src_rank);
syncwarp();
// Copy data
// NOTES: only 2 load iterations for 7K hidden with 7 unrolls
src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
}
const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// Copy scales
if constexpr(kUseQuant8Bit) {
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
if constexpr(kQuantGroupSize == 0) {
if (lane_id == 0) {
recv_x_scales[token_idx] = ld_nc_global(src_scales);
}
} else {
if (lane_id < kNumScales) {
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + kWarpSize < kNumScales) {
const auto pack_idx = (lane_id + kWarpSize) / num_elems_per_pack;
const auto elem_idx = (lane_id + kWarpSize) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + kWarpSize));
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
}
}
}
}
}
void dispatch_ll_layered(bool dispatch_ll_dispatch_opt,
void* packed_recv_x, void* packed_recv_x_scales,
int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* global_atomic_counter,
void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int64_t* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int quant_type, int quant_group_size, bool fp8_round_scale,
void* workspace, int num_device_sms,
hipStream_t stream, int phases) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopK = 11;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = kMaxNumWarps / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = ceil_div(num_experts, num_warp_groups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
// 限制groupsize的大小
EP_HOST_ASSERT(quant_group_size == 0 || quant_group_size == 128);
/*量化类型枚举
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/
#define DISPATCH_LL_LAUNCH_CASE(hidden) \
{ \
auto dispatch_func = dispatch_ll_layered<hidden, 0, 0, kMaxNumWarps>; \
if (quant_group_size == 0) { \
switch (quant_type) { \
case 1: dispatch_func = dispatch_ll_layered<hidden, 1, 0, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch_ll_layered<hidden, 2, 0, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch_ll_layered<hidden, 3, 0, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch_ll_layered<hidden, 4, 0, kMaxNumWarps>; break; \
} \
} else { \
switch (quant_type) { \
case 1: dispatch_func = dispatch_ll_layered<hidden, 1, 128, kMaxNumWarps>; break; \
case 2: dispatch_func = dispatch_ll_layered<hidden, 2, 128, kMaxNumWarps>; break; \
case 3: dispatch_func = dispatch_ll_layered<hidden, 3, 128, kMaxNumWarps>; break; \
case 4: dispatch_func = dispatch_ll_layered<hidden, 4, 128, kMaxNumWarps>; break; \
} \
} \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, dispatch_func, dispatch_ll_dispatch_opt, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, packed_recv_count, \
global_atomic_counter, \
rdma_recv_x, rdma_recv_count, rdma_x, x, topk_idx, \
atomic_counter_per_expert, atomic_finish_counter_per_expert, \
next_clean, num_next_clean_int, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, fp8_round_scale, phases); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(DISPATCH_LL_LAUNCH_CASE);
#undef DISPATCH_LL_LAUNCH_CASE
}
/*
combine 启用 overlop 后的实现
*/
template <int kHidden, int kNumMaxTopk, int kMaxNumWarps=16>
__global__ __launch_bounds__(16 * kWarpSize, 1) void
combine_sbo(bool disable_ll_layered,
void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap specific parameters
int* packed_recv_count, int* comp_signal, int block_m, int threshold,
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
int* atomic_clean_flag, int* atomic_finish_counter_per_expert,
int num_combined_tokens, int hidden, int num_topk,
int num_max_dispatch_tokens_per_rank,
int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
int phases, bool zero_copy) {
// 假设 启用 3 个block
const auto sm_id = static_cast<int>(blockIdx.x);
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto num_threads = static_cast<int>(blockDim.x);
const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
const auto num_local_experts = num_experts / num_ranks; // 16
const auto warp_group_id = warp_id / num_warps_per_group; // 0 0 0 ... 0
const auto sub_warp_id = warp_id % num_warps_per_group; // 0 1 2 ... 15
const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; // 这意味着 一次 并行处理 3个专家 0 1 2
int* next_clean_data_ready_counter = reinterpret_cast<int*>(next_clean + num_experts);
const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
const auto num_nodes = num_ranks / num_nvl_ranks;
// hidden_bf16_int4: bf16 的 token 包含多少个 int4
constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(hip_bfloat16);
const size_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4;
// Message package
EP_STATIC_ASSERT(kHidden % QUANTIZATION_GROUPSIZE == 0, "Invalid hidden");
constexpr size_t num_bytes_per_slot = kHidden * sizeof(hip_bfloat16);
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// Shared between warps in sms for overlap mode, where each sm only has one warp group
__shared__ volatile int shared_vaild_signal_prefix_sum[40]; // 用于统计 本地专家 有效信号 的 前缀和
// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_COMBINE_RECV;
if (!disable_ll_layered and sm_id == num_sms - 1) {
#pragma unroll
for (int i = thread_id; i < num_experts; i += num_threads)
next_clean[i] = 0;
// clean data ready flag
for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += num_threads) {
int token_idx = i / num_ranks;
int rank_id = i % num_ranks;
{
auto node_id = rank_id / num_nvl_ranks;
auto nvl_rank_id = rank_id % num_nvl_ranks;
auto* data_ready_flag_ptr = reinterpret_cast<int*>(next_clean_data_ready_counter) +
node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks + token_idx * num_nvl_ranks + rank % num_nvl_ranks;
EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter <
num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks * sizeof(int));
const auto data_ready_p2p_src_ptr =
internode::shmem_get_p2p_ptr((void*)(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
}
}
// Notify before executing `int_p`
__syncthreads();
if (thread_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
if (disable_ll_layered) {
// Clean up next buffer
if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
next_clean[i] = 0;
// Notify before executing `int_p`
syncwarp();
if (lane_id == 0)
atomic_add_release_global(atomic_clean_flag, num_experts);
}
}
__syncthreads();
// ========================================
// shared_vaild_signal_sum: 本地专家的总信号量
// shared_local_expert_idx: 共享内存中的 本地专家索引。初始置为 0 , 表明 当前 block 当前在 处理的 本地专家索引
__shared__ int shared_vaild_signal_sum, shared_local_expert_idx;
// 计算每个 本地专家 有效信号 计数 的 前缀和,即使没有 token, 也算作一个 任务
if (sub_warp_id == 0 and lane_id == 0) { // 0号 warp 的 0号线程 执行下述操作
shared_vaild_signal_prefix_sum[0] = (packed_recv_count[0] == 0 ? 1 : ceil_div(packed_recv_count[0], block_m));
shared_local_expert_idx = 0; // 共享内存中 本地专家索引 置为 0
for (int i = 1; i < num_local_experts; i++) {
shared_vaild_signal_prefix_sum[i] =
shared_vaild_signal_prefix_sum[i - 1] + (packed_recv_count[i] == 0 ? 1 : ceil_div(packed_recv_count[i], block_m));
}
shared_vaild_signal_sum = shared_vaild_signal_prefix_sum[num_local_experts - 1];
}
__syncthreads(); // 等待前缀和 统计完成 16个 warp 同步等待
// 每个 block 负责一个 处理信号,并循环处理到 最后
for (int vaild_signal_idx = sm_id; vaild_signal_idx < shared_vaild_signal_sum; vaild_signal_idx += num_sms) {
// ====================== 16个 warp 进入 ======================
// 通过扫描前缀和数组找到当前处理的本地专家索引,并记录在 shared_local_expert_idx
if (sub_warp_id == 0 and lane_id == 0) {
while (vaild_signal_idx >= shared_vaild_signal_prefix_sum[shared_local_expert_idx])
shared_local_expert_idx++;
}
__syncthreads();
// ===========================================
// shared_local_expert_idx: 当前处理的任务块 是哪个本地专家
// 上述 操作 确定了 当前 block 负责处理的本地专家为 shared_local_expert_idx
// 需要依据 shared_local_expert_idx 本地索引确定其他 地址
const auto local_expert_idx = shared_local_expert_idx; // 当前处理 的 本地专家索引
const auto global_expert_idx = rank * num_local_experts + local_expert_idx; // 获取 本地专家 在全局中的索引
const auto local_x = static_cast<const int4*>(x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4;
const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
const auto rdma_send_x_vec = static_cast<uint8_t*>(rdma_send_x) +
local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot;
// ================================ 等待相应的 comp_signal 达到阈值
//----------------------- 确定 当前等待的信号量位置
// num_tokens_per_expert:当前 负责的专家 dispatch 阶段 接收的 总 token 数
// num_signal_per_expert:当前 负责的专家 需要等待的总 信号 数
// local_expert_signal_idx: 当前处理的信号总索引,是 当前处理专家的 第几个信号
int num_tokens_per_expert, num_signal_per_expert, local_expert_signal_idx;
const int* gemm_comp_signal;
num_tokens_per_expert = packed_recv_count[local_expert_idx]; // 当前专家 dispatch 阶段接收的 总 token 数
num_signal_per_expert = ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m); // 每个专家的 最大 信号数
local_expert_signal_idx =
(local_expert_idx == 0) ? vaild_signal_idx : vaild_signal_idx - shared_vaild_signal_prefix_sum[local_expert_idx - 1]; // 当前专家 中的 信号索引
gemm_comp_signal = comp_signal + num_signal_per_expert * local_expert_idx + local_expert_signal_idx;
//----------------------- 循环等待 信号量到达 阈值
if (sub_warp_id == 0 and lane_id == 0 and num_tokens_per_expert != 0) { // 当前专家 dispatch 阶段接收的 token 数 不是 0 的话,循环等待 信号量的值 到达 阈值
while (ld_acquire_global(gemm_comp_signal) != threshold)
;
}
__syncthreads();
// ============================== 发射 RDMA 指令 ==============================
// ------------------------------ 确定 处理的 token 起始位置 和 结束位置 -----------------
auto token_start_idx = local_expert_signal_idx * block_m;
auto token_end_idx = min((local_expert_signal_idx + 1) * block_m, num_tokens_per_expert);
// 16个 warp 每个warp 负责一个 token 的发射
for (int token_idx = sub_warp_id + token_start_idx; token_idx < token_end_idx; token_idx += num_warps_per_group) {
const auto x_int4 = local_x + token_idx * hidden_bf16_int4;
const auto rdma_send_type_row = reinterpret_cast<int*>(rdma_send_x_vec + token_idx * num_bytes_per_slot);
const auto rdma_send_x_vec_row = reinterpret_cast<uint8_t*>(rdma_send_type_row);
const auto dst_rank = static_cast<int>(__ldg(local_src_info + token_idx) >> 32);
const auto src_idx = static_cast<int>(__ldg(local_src_info + token_idx) & 0xffffffff);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy){
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
}
internode_ll_putmem_nbi((void*)dst_ptr, (void*)buf_ptr,
num_ranks, dst_rank, local_expert_idx,
hidden * sizeof(hip_bfloat16));
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(x_int4);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(p2p_ptr);
UNROLLED_WARP_COPY_LL(7, lane_id, hidden_bf16_int4, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}
}
__syncthreads(); // 等待 16 个 warp 都完成 RDMA 发射
// ================================= 当前所有 RDMA 下发完成后,判断是不是要 发射 完成的 flag=====================================
bool put_finish_flag = false; // 标记是不是要发射 RDMA 结束标记
// 判断是不是 到了 当前专家处理的 最后
if (sub_warp_id == 0) { //
if (lane_id == 0) {
const auto finish_counter = (num_tokens_per_expert == 0 ? 1 : ceil_div(num_tokens_per_expert, block_m)); // 获取当前专家 发送的 总 的信号数
if ((atomicAdd(atomic_finish_counter_per_expert + local_expert_idx, 1) + 1) == finish_counter)
put_finish_flag = true;
}
put_finish_flag = shfl_sync(put_finish_flag, 0);
}
__syncthreads();
// 通知其他 所有 rank,当前本地专家的 token 已经发射完成
if (sub_warp_id == 0 and put_finish_flag) {
for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 64) {
while (ld_acquire_global(atomic_clean_flag) == 0);
auto dst_ptr = rdma_recv_flag + global_expert_idx;
// 通过 shmem_get_p2p_ptr 获取 当前远程指针能否可达
uint64_t p2p_ptr = internode::shmem_get_p2p_ptr((void*)dst_ptr, rank, dst_rank);
if (p2p_ptr == 0) { // RDMA
internode_ll_long_atomic_add(dst_ptr, 1, num_ranks, dst_rank, local_expert_idx);
} else { // 本地 GPU 和 同一计算节点的 其他 GPU 地址
st_na_release(reinterpret_cast<int *>(p2p_ptr), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
if (lane_id == 0) // 清理 标记数组
atomic_finish_counter_per_expert[local_expert_idx] = 0;
}
__syncthreads();
}
// Receiving phase
LOW_LATENCY_COMBINE_RECV:
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;
// Wait all ranks to arrive and notify PCIe usage
if (responsible_expert_idx < num_experts) {
// EP_DEVICE_ASSERT(num_warps_per_group > 1);
if (sub_warp_id == 0 and lane_id == 0) {
const auto src_rank = responsible_expert_idx / num_local_experts;
auto start_time = wall_clock64();
uint64_t wait_recv_cost = 0;
while (ld_acquire_global(reinterpret_cast<int*>(rdma_recv_flag + responsible_expert_idx)) == 0 // recv not ready
&& (wait_recv_cost = wall_clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout
);
// Mask rank if timeout
if (wait_recv_cost > NUM_TIMEOUT_CYCLES) {
printf("Warning: DeepEP timeout for combine receive, rank %d, local_expert_idx %d, src_rank %d\n",
rank, responsible_expert_idx % num_local_experts, src_rank);
}
if (combine_wait_recv_cost_stats != nullptr) {
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
}
}
}
grid_barrier(global_atomic_counter, num_sms);
// Reduce tokens with FP8 cast
// EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
float reg_topk_weights[kNumMaxTopk];
#pragma unroll
for (int i = 0; i < num_topk; ++ i) {
reg_topk_idx[i] = static_cast<int>(__ldg(topk_idx + token_idx * num_topk + i));
reg_topk_weights[i] = __ldg(topk_weights + token_idx * num_topk + i);
}
float combined_values[kNumElemsPerInt4] = {0.0f};
#pragma unroll
for (int i = 0; i < num_topk; ++ i) if (reg_topk_idx[i] >= 0) {
// Read from sources
auto rdma_buffer_type = reinterpret_cast<const int*>(reinterpret_cast<uint8_t*>(rdma_recv_x) +
(reg_topk_idx[i] * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot);
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
const auto x_bf16 = reinterpret_cast<hip_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_values[j] += static_cast<float>(x_bf16[j]) * reg_topk_weights[i];
}
// Write results
int4& combined_int4 = *reinterpret_cast<int4*>(combined_values);
auto combined_bf16 = reinterpret_cast<hip_bfloat16*>(&combined_values);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<hip_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
}
}
}
void combine_sbo(void* combined_x,
void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int64_t* src_info, const int64_t* layout_range,
// Overlap 新增控制参数
bool disable_ll_layered,
int* packed_recv_count, int* comp_signal,
int block_m, int threshold, int num_sms,
// 同步与统计参数
int* global_atomic_counter,
int64_t* combine_wait_recv_cost_stats,
int64_t* next_clean, int num_next_clean_int,
// 维度与配置参数
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
// 系统资源与执行参数
void* workspace, int num_device_sms, hipStream_t stream,
int phases, bool zero_copy) {
constexpr int kMaxNumWarps = 16;
constexpr int kNumMaxTopk = 11;
int num_warp_groups, num_warps_per_group, num_recv_per_sm, num_warps;
if (phases == LOW_LATENCY_SEND_PHASE) { // 如果启用 overlop 必须是 send 阶段
num_warp_groups = 1; // 一个 block 只有一个 warp 组
num_warps_per_group = 16; // 16 个 warp 每个 warp 64 线程
num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0 and block_m > 0 and threshold > 0);
num_warps = num_warp_groups * num_warps_per_group;
} else {
num_warp_groups = ceil_div(num_experts, num_device_sms);
num_warps_per_group = kMaxNumWarps / num_warp_groups;
num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms);
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0);
num_warps = num_warp_groups * num_warps_per_group;
num_sms = max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm));
}
// Check workspace
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
auto atomic_finish_counter_per_expert = atomic_clean_flag + 1; // overlop 新增使用
EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES);
EP_HOST_ASSERT(num_topk <= kNumMaxTopk);
#define COMBINE_OVERLOP_LAUNCH_CASE(hidden) \
{ \
auto combine_overlop_func = combine_sbo<hidden, kNumMaxTopk, kMaxNumWarps>; \
LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_overlop_func, \
disable_ll_layered, \
combined_x, rdma_recv_x, rdma_recv_flag, rdma_send_x, \
x, topk_idx, topk_weights, src_info, layout_range, \
packed_recv_count, comp_signal, block_m, threshold, \
global_atomic_counter, combine_wait_recv_cost_stats, \
next_clean, num_next_clean_int, \
atomic_clean_flag, atomic_finish_counter_per_expert, \
num_combined_tokens, hidden, \
num_topk, num_max_dispatch_tokens_per_rank, \
num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, phases, zero_copy); \
} \
break
SETUP_LAUNCH_CONFIG(num_sms, num_warps * kWarpSize, stream);
SWITCH_HIDDEN(COMBINE_OVERLOP_LAUNCH_CASE);
#undef COMBINE_OVERLOP_LAUNCH_CASE
}
} // namespace internode_ll
} // namespace deep_ep
......@@ -40,6 +40,8 @@ class Buffer:
allow_mnnvl: bool = False,
explicitly_destroy: bool = False,
enable_shrink: bool = False,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
) -> None:
"""
Initialize the communication buffer.
......@@ -60,6 +62,8 @@ class Buffer:
otherwise, the resources will be released by the destructor.
Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically.
enable_dispatch_ll_layered: Enable low-latency mode with hierarchical dispatch operators.
enable_combine_overlap: deepgemm DOWN gemm overlop combine send
"""
check_nvlink_connections(group)
......@@ -72,6 +76,10 @@ class Buffer:
self.low_latency_mode = low_latency_mode
self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink
if enable_dispatch_ll_layered and enable_shrink: # Currently, the layered algorithm for ll dispatch has been optimized, so the shrink mode is no longer supported.
print("DeepEP [ERROR] not support shrink, disable it", flush=True)
enable_shrink = False
self.runtime = deep_ep_cpp.Buffer(
self.rank,
self.group_size,
......@@ -79,7 +87,9 @@ class Buffer:
num_rdma_bytes,
low_latency_mode,
explicitly_destroy,
enable_shrink
enable_shrink,
enable_dispatch_ll_layered,
enable_combine_overlap
)
# Synchronize device IDs
......@@ -212,7 +222,8 @@ class Buffer:
@staticmethod
def get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int, quant_group_size: int = 0
num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int,
enable_dispatch_ll_layered: bool = False, quant_group_size: int = 0
) -> int:
"""
Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16.
......@@ -228,7 +239,8 @@ class Buffer:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts, quant_group_size
num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered, quant_group_size
)
def get_comm_stream(self) -> torch.Stream:
......@@ -921,9 +933,11 @@ class Buffer:
recv_x = (packed_recv_x, packed_recv_x_scales) if (quant_type > 0) else packed_recv_x
return recv_x, packed_recv_count, handle, EventOverlap(event, tensors_to_record if async_finish else None), hook
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, use_logfmt: bool = False,
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple,
# combine sbo params
packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None,
block_m: int = -1, threshold: int = -1, num_sms: int = -1,
use_logfmt: bool = False,
zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
......@@ -945,13 +959,13 @@ class Buffer:
topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched
tokens. The received tokens will be reduced with the weights in this tensor.
handle: the communication handle given by the `dispatch` function.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative
with `get_next_low_latency_combine_buffer`.
async_finish: the current stream will not wait for the communication kernels to be finished if set.
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
If you not set this flag, the kernel will ensure the data's arrival.
use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits).
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics,
which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`.
......@@ -964,6 +978,7 @@ class Buffer:
"""
src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
packed_recv_count, comp_signal, block_m, threshold, num_sms,
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook, out)
......
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
# export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
......@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
#!/bin/bash
# rocSHMEM
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export ROCSHMEM_HEAP_SIZE=10737418240
export ROCSHMEM_HEAP_SIZE=3737418240
export ROCSHMEM_TOPO_FILE_FORCE=./topo.config
# duSHMEM
export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
export NVSHMEM_SYMMETRIC_SIZE=10737418240
# # duSHMEM
# export LD_LIBRARY_PATH=/opt/dtk/dushmem/lib:$LD_LIBRARY_PATH
# export DEEP_EP_DEVICE_TO_HCA_MAPPING=0:mlx5_2:1,1:mlx5_3:1,2:mlx5_4:1,3:mlx5_5:1,4:mlx5_6:1,5:mlx5_7:1,6:mlx5_8:1,7:mlx5_9:1
# export NVSHMEM_SYMMETRIC_SIZE=10737418240
# common
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
......@@ -17,6 +18,7 @@ export PYTHONPATH=$(pwd)/../
# test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_internode.py --test-ll-compatibility
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py # --pressure-test
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --use-logfmt
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 ./test_low_latency.py --enable-dispatch-ll-layered --enable-combine-overlap
......@@ -34,6 +34,10 @@ def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], bu
assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks
def ceil_div(a, b):
return (a + b - 1) // b
def test_main(num_tokens: int,
hidden: int,
num_experts: int,
......@@ -42,11 +46,16 @@ def test_main(num_tokens: int,
num_ranks: int,
group: dist.ProcessGroup,
buffer: deep_ep.Buffer,
enable_dispatch_ll_layered: bool = False,
enable_combine_overlap: bool = False,
use_logfmt: bool = False,
seed: int = 0):
torch.manual_seed(seed + rank)
random.seed(seed + rank)
print(f"enable_dispatch_ll_layered={enable_dispatch_ll_layered}, enable_combine_overlap={enable_combine_overlap}, use_logfmt={use_logfmt}")
assert not (use_logfmt and (enable_dispatch_ll_layered or enable_combine_overlap)), \
"use_logfmt=True and enable_dispatch_ll_layered/enable_combine_overlap conflict"
assert num_experts % num_ranks == 0
num_local_experts = num_experts // num_ranks
......@@ -84,10 +93,13 @@ def test_main(num_tokens: int,
hash_value, num_times = 0, 0
for x_i, current_x in enumerate(x_list):
for return_recv_hook in (False, True):
for quant_type in (0, 1, 2, 3, ): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
if enable_combine_overlap and (not return_recv_hook): # return_recv_hook 为False 时,不能启用 overlop
continue
for quant_type in (0, 1, 2, 3,): # 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant = quant_type > 0
for fp8_round_scale in (False, True) if quant_type != 3 else (True, ):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0, ):
for fp8_round_scale in (False, True) if quant_type != 3 else (True,):
for quant_group_size in (0, 128,) if quant_type >= 2 else (0,):
if quant_type == 3 and (fp8_round_scale == False or quant_group_size == 0):
continue
......@@ -131,7 +143,12 @@ def test_main(num_tokens: int,
recv_x = recv_x[:num_valid_tokens]
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
recv_x_amax = recv_x[:, :-128].amax(dim=-1)
if (enable_dispatch_ll_layered or enable_combine_overlap):
recv_src_info = recv_src_info[:num_valid_tokens] & int_mask # 掩掉多余的信息
else:
recv_src_info = recv_src_info[:num_valid_tokens]
assert torch.equal(recv_x_amin, recv_x_amax)
if dispatch_use_quant:
......@@ -148,6 +165,7 @@ def test_main(num_tokens: int,
if not fp8_round_scale:
assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item()
assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if dispatch_use_quant:
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
......@@ -155,10 +173,32 @@ def test_main(num_tokens: int,
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
# Check combine correctness
for zero_copy in (False, ) if use_logfmt else (False, True, ):
for zero_copy in (False,) if use_logfmt else (False, True,):
if zero_copy:
buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
if enable_combine_overlap:
block_m, threshold, num_sms = 64, 10, 3
total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) # 每个本地专家 总的信号数??
comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
vaild_num = ceil_div(packed_recv_count[i], block_m)
comp_signal[i * total_num_per_expert:i * total_num_per_expert + vaild_num] = threshold
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
handle,
packed_recv_count=packed_recv_count,
comp_signal=comp_signal,
block_m=block_m,
threshold=threshold,
num_sms=num_sms,
async_finish=not return_recv_hook,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
else:
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x,
topk_idx,
topk_weights,
......@@ -168,6 +208,7 @@ def test_main(num_tokens: int,
zero_copy=zero_copy,
return_recv_hook=return_recv_hook,
out=out)
hook() if return_recv_hook else event.current_stream_wait()
if do_check:
diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
......@@ -180,6 +221,10 @@ def test_main(num_tokens: int,
print(f"data:{x_i}, return_recv_hook:{return_recv_hook}, quant_type:{quant_type}, ",
f"fp8_round_scale:{fp8_round_scale}, quant_group_size:{quant_group_size} pass")
print("deep_ep 全部正确性测试完成")
if enable_dispatch_ll_layered or enable_combine_overlap:
return hash_value
# noinspection PyShadowingNames
def large_gemm_with_hook(hook):
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
......@@ -242,7 +287,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_tokens, hidden = args.num_tokens, args.hidden
num_topk, num_experts = args.num_topk, args.num_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts)
enable_dispatch_ll_layered = args.enable_dispatch_ll_layered
enable_combine_overlap = args.enable_combine_overlap
if enable_dispatch_ll_layered:
enable_combine_overlap = True
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts,
enable_dispatch_ll_layered=enable_dispatch_ll_layered)
if local_rank == 0:
print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True)
buffer = deep_ep.Buffer(group,
......@@ -251,7 +302,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_qps_per_rank=num_experts // num_ranks,
allow_nvlink_for_low_latency_mode=not args.disable_nvlink,
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl)
allow_mnnvl=args.allow_mnnvl,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap
)
print("deep_ep 初始化完成")
test_main(num_tokens,
hidden,
num_experts,
......@@ -261,6 +316,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=1)
do_pressure_test = args.pressure_test
......@@ -276,6 +333,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed)
for _ in range(20):
assert test_main(num_tokens,
......@@ -287,6 +346,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
group,
buffer,
use_logfmt=args.use_logfmt,
enable_dispatch_ll_layered=enable_dispatch_ll_layered,
enable_combine_overlap=enable_combine_overlap,
seed=seed) == ref_hash, f'Error: seed={seed}'
# Destroy the buffer runtime and communication group
......@@ -309,6 +370,10 @@ if __name__ == '__main__':
parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test')
parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode')
parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine')
# 新版 sbo 需要的
parser.add_argument('--enable-dispatch-ll-layered', action='store_true', help='Enable low-latency layered dispatch optimization')
parser.add_argument("--enable-combine-overlap", action='store_true', help='Enable GEMM-compute/communication overlap in the combine phase')
args = parser.parse_args()
num_processes = args.num_processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment