Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: ...@@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
num_tokens_range = [1, 16, 256, 4096] num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512] num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8] topk_range = [1, 2, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) ep_size_range = [1, 8]
configs = list(
itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range)
)
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"], x_names=["num_tokens", "num_experts", "topk", "ep_size"],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["vllm"], line_vals=["vllm"],
...@@ -38,16 +41,26 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range ...@@ -38,16 +41,26 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
args={}, args={},
) )
) )
def benchmark(num_tokens, num_experts, topk, provider): def benchmark(num_tokens, num_experts, topk, ep_size, provider):
"""Benchmark function for Triton.""" """Benchmark function for Triton."""
block_size = 256 block_size = 256
torch.cuda.manual_seed_all(0)
topk_ids = get_topk_ids(num_tokens, num_experts, topk) topk_ids = get_topk_ids(num_tokens, num_experts, topk)
e_map = None
if ep_size != 1:
local_e = num_experts // ep_size
e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == "vllm": if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench( ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size(topk_ids, block_size, num_experts), lambda: moe_align_block_size(
topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True
),
quantiles=quantiles, quantiles=quantiles,
) )
......
...@@ -1246,14 +1246,8 @@ class AttentionMainLoop { ...@@ -1246,14 +1246,8 @@ class AttentionMainLoop {
// rescale sum and partial outputs // rescale sum and partial outputs
if (need_rescale) { if (need_rescale) {
// compute rescale factor // compute rescale factor
#ifdef DEFINE_FAST_EXP
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
rescale_factor_vec = fast_exp(rescale_factor_vec);
rescale_factor = rescale_factor_vec.get_last_elem();
#else
rescale_factor = std::exp(rescale_factor); rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
#endif
// rescale sum // rescale sum
new_sum_val += rescale_factor * init_sum_val; new_sum_val += rescale_factor * init_sum_val;
...@@ -1889,15 +1883,8 @@ class AttentionMainLoop { ...@@ -1889,15 +1883,8 @@ class AttentionMainLoop {
: curr_output_buffer; : curr_output_buffer;
float rescale_factor = final_max > curr_max ? curr_max - final_max float rescale_factor = final_max > curr_max ? curr_max - final_max
: final_max - curr_max; : final_max - curr_max;
#ifdef DEFINE_FAST_EXP
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
rescale_factor_vec = fast_exp(rescale_factor_vec);
rescale_factor = rescale_factor_vec.get_last_elem();
#else
rescale_factor = std::exp(rescale_factor); rescale_factor = std::exp(rescale_factor);
vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor);
#endif
local_sum[head_idx] = final_max > curr_max local_sum[head_idx] = final_max > curr_max
? final_sum + rescale_factor * curr_sum ? final_sum + rescale_factor * curr_sum
......
...@@ -60,4 +60,54 @@ ...@@ -60,4 +60,54 @@
#endif #endif
#ifdef __aarch64__
// Implementation copied from Arm Optimized Routines (expf AdvSIMD)
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
#include <limits>
// DEFINE_FAST_EXP expands, in the scope where it is used, to a set of local
// constants plus two lambdas:
//   * neon_expf(float32x4_t): 4-lane single-precision exp approximation.
//   * fast_exp(vec_op::FP32Vec16&): applies neon_expf to each of the four
//     float32x4_t registers vec.reg.val[0..3] and returns a new FP32Vec16.
//
// neon_expf, step by step:
//   1. n = x * inv_ln2 rounded to an integer (vrndaq_f32); the reduced
//      argument is r = x - n * ln2, with ln2 split into ln2_hi + ln2_lo and
//      subtracted in two fused multiply-subtract steps.
//   2. scale = 2^n, built by shifting the integer n into the float exponent
//      field and adding the bit pattern of 1.0f (0x3f800000).
//   3. poly is a polynomial in r assembled from c0..c4; the result is
//      scale + poly * scale.
//   4. Lanes with x >= 0x1.5d5e2ap+6f are replaced by +inf and lanes with
//      x <= -0x1.5d5e2ap+6f are replaced by 0.
//
// Both lambdas capture by reference ([&]) the constants declared by the same
// expansion, so they are only usable inside the scope the macro was expanded
// in. Do not add `//` comments inside the macro body: every line is joined by
// a trailing backslash.
#define DEFINE_FAST_EXP \
const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
const float ln2_hi = 0x1.62e4p-1f; \
const float ln2_lo = 0x1.7f7d1cp-20f; \
const float c0 = 0x1.0e4020p-7f; \
const float c2 = 0x1.555e66p-3f; \
const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \
const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \
const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \
const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \
const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \
const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \
const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \
const float32x4_t inf = \
vdupq_n_f32(std::numeric_limits<float>::infinity()); \
const float32x4_t zero = vdupq_n_f32(0.0f); \
auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \
float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \
r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \
uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \
float32x4_t r2 = vmulq_f32(r, r); \
float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \
float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \
q = vfmaq_f32(q, p, r2); \
p = vmulq_f32(c4, r); \
float32x4_t poly = vfmaq_f32(p, q, r2); \
poly = vfmaq_f32(scale, poly, scale); \
const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \
const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
auto fast_exp = [&](vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
result.val[1] = neon_expf(vec.reg.val[1]); \
result.val[2] = neon_expf(vec.reg.val[2]); \
result.val[3] = neon_expf(vec.reg.val[3]); \
return vec_op::FP32Vec16(result); \
};
#endif // __aarch64__
#endif #endif
\ No newline at end of file
...@@ -118,6 +118,24 @@ ...@@ -118,6 +118,24 @@
} \ } \
} }
// Lifts a runtime boolean into a compile-time constant.
//
//   expr        runtime condition (evaluated once).
//   const_expr  name of the `constexpr bool` made visible to the callable.
//   ...         a callable (typically a lambda) that is invoked with no
//               arguments in the branch matching `expr`; inside it,
//               `const_expr` can be used as a template argument.
#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
  if (!(expr)) {                                  \
    constexpr bool const_expr = false;            \
    __VA_ARGS__();                                \
  } else {                                        \
    constexpr bool const_expr = true;             \
    __VA_ARGS__();                                \
  }
// Lifts a runtime quantization group size into a compile-time constant.
//
//   group_size        runtime value; only 128 and 64 are supported.
//   const_group_size  name of the `constexpr int` made visible to the
//                     callable.
//   ...               a callable (typically a lambda) invoked with no
//                     arguments in the matching branch.
//
// Any other group size raises via TORCH_CHECK instead of silently skipping
// the callable, so a caller can never end up with untouched output buffers.
// `group_size` is parenthesized so that compound expressions compare as
// intended.
#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
  if ((group_size) == 128) {                                        \
    constexpr int const_group_size = 128;                           \
    __VA_ARGS__();                                                  \
  } else if ((group_size) == 64) {                                  \
    constexpr int const_group_size = 64;                            \
    __VA_ARGS__();                                                  \
  } else {                                                          \
    TORCH_CHECK(false, "Unsupported group size: ", (group_size));   \
  }
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \ #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \ switch (NUM_DIMS) { \
case 2: { \ case 2: { \
......
...@@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) { ...@@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
return cuda_cast<T, float>(sigmoid_accurate(f)); return cuda_cast<T, float>(sigmoid_accurate(f));
} }
template <typename T> template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
return val;
}
}
template <typename T, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, T const* bias, __device__ void topk_with_k2(T* output, T const* input, T const* bias,
cg::thread_block_tile<32> const& tile, cg::thread_block_tile<32> const& tile,
int32_t const lane_id, int32_t const lane_id,
int const num_experts_per_group, int const num_experts_per_group) {
int const scoring_func) {
// Get the top2 per thread // Get the top2 per thread
T largest = neg_inf<T>(); T largest = neg_inf<T>();
T second_largest = neg_inf<T>(); T second_largest = neg_inf<T>();
if (num_experts_per_group > WARP_SIZE) { if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i]; T value = apply_scoring<SF>(input[i]);
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
value = value + bias[i]; value = value + bias[i];
if (value > largest) { if (value > largest) {
...@@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, ...@@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
} }
} else { } else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i]; T value = apply_scoring<SF>(input[i]);
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
value = value + bias[i]; value = value + bias[i];
largest = value; largest = value;
} }
...@@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, ...@@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
} }
} }
template <typename T> template <typename T, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
int64_t const num_tokens, int64_t const num_tokens,
int64_t const num_cases, int64_t const num_cases,
int64_t const n_group, int64_t const n_group,
int64_t const num_experts_per_group, int64_t const num_experts_per_group) {
int const scoring_func) {
int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE;
...@@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, ...@@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;"); asm volatile("griddepcontrol.wait;");
#endif #endif
topk_with_k2(output, input, group_bias, tile, lane_id, topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
num_experts_per_group, scoring_func); num_experts_per_group);
} }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;"); asm volatile("griddepcontrol.launch_dependents;");
#endif #endif
} }
template <typename T, typename IdxT> template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel( __global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group, T const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize, int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor, int scoring_func) { double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id = int32_t case_id =
...@@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
topk_values += case_id * topk; topk_values += case_id * topk;
topk_indices += case_id * topk; topk_indices += case_id * topk;
constexpr bool kUseStaticNGroup = (NGroup > 0);
// use int32 to avoid implicit conversion
int32_t const n_group_i32 =
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
int32_t align_num_experts_per_group = int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group); warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
...@@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) { if (case_id < num_tokens) {
// calculate group_idx // calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group; int32_t target_num_min =
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
// The check is necessary to avoid abnormal input // The check is necessary to avoid abnormal input
if (lane_id < n_group && is_finite(group_scores[lane_id])) { if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id]; value = group_scores[lane_id];
} }
int count_equal_to_top_value = WARP_SIZE - n_group; int count_equal_to_top_value = WARP_SIZE - n_group_i32;
int pre_count_equal_to_top_value = 0; int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group // Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) { while (count_equal_to_top_value < target_num_min) {
...@@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int count_equalto_topkth_group = 0; int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = topk_group_value != neg_inf<T>(); bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
if (case_id < num_tokens && if_proceed_next_topk) { if (case_id < num_tokens && if_proceed_next_topk) {
for (int i_group = 0; i_group < n_group; i_group++) { auto process_group = [&](int i_group) {
if ((group_scores[i_group] > topk_group_value) || if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) && ((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) { (count_equalto_topkth_group < num_equalto_topkth_group))) {
...@@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
i += WARP_SIZE) { i += WARP_SIZE) {
T candidates = neg_inf<T>(); T candidates = neg_inf<T>();
if (i < num_experts_per_group) { if (i < num_experts_per_group) {
// Apply scoring function (if any) and add bias // apply scoring function (if any) and add bias
T input = scores[offset + i]; T input = scores[offset + i];
if (is_finite(input)) { if (is_finite(input)) {
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) T score = apply_scoring<SF>(input);
: input;
candidates = score + bias[offset + i]; candidates = score + bias[offset + i];
} }
} }
...@@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
count_equalto_topkth_group++; count_equalto_topkth_group++;
} }
} }
};
if constexpr (kUseStaticNGroup) {
#pragma unroll
for (int i_group = 0; i_group < NGroup; ++i_group) {
process_group(i_group);
}
} else {
for (int i_group = 0; i_group < n_group_i32; ++i_group) {
process_group(i_group);
}
} }
queue.done(); queue.done();
__syncwarp(); __syncwarp();
...@@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) { if (i < topk) {
// Load the score value (without bias) for normalization // Load the score value (without bias) for normalization
T input = scores[s_topk_idx[i]]; T input = scores[s_topk_idx[i]];
value = value = apply_scoring<SF>(input);
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
s_topk_value[i] = value; s_topk_value[i] = value;
} }
topk_sum += if (renormalize) {
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>()); topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
} }
} }
...@@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) { if (case_id < num_tokens) {
if (if_proceed_next_topk) { if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) { for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value; float base = cuda_cast<float, T>(s_topk_value[i]);
if (renormalize) { float value = renormalize ? (base / topk_sum * routed_scaling_factor)
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum * : (base * routed_scaling_factor);
routed_scaling_factor;
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
topk_indices[i] = s_topk_idx[i]; topk_indices[i] = s_topk_idx[i];
topk_values[i] = value; topk_values[i] = value;
} }
...@@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel( ...@@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif #endif
} }
// Launches group_idx_and_topk_idx_kernel with the launch configuration
// `config`, choosing a specialization whose NGroup template argument equals
// `n_group` when that value is 4, 8, 16 or 32. Every other value uses the
// kernel's default NGroup, in which case the kernel reads the group count
// from its runtime `n_group` argument. All remaining arguments are forwarded
// to the kernel unchanged.
template <typename T, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
    cudaLaunchConfig_t const& config, T* scores, T* group_scores,
    float* topk_values, IdxT* topk_indices, T const* bias,
    int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
    int64_t const topk, int64_t const num_experts,
    int64_t const num_experts_per_group, bool const renormalize,
    double const routed_scaling_factor) {
  // Shared launch call; only the kernel instantiation differs per branch.
  auto submit = [&](auto* selected_kernel) {
    cudaLaunchKernelEx(&config, selected_kernel, scores, group_scores,
                       topk_values, topk_indices, bias, num_tokens, n_group,
                       topk_group, topk, num_experts, num_experts_per_group,
                       renormalize, routed_scaling_factor);
  };

  if (n_group == 4) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
  } else if (n_group == 8) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
  } else if (n_group == 16) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
  } else if (n_group == 32) {
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
  } else {
    // Generic path: group count stays a runtime value.
    submit(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
  }
}
template <typename T, typename IdxT> template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens, IdxT* topk_indices, T const* bias, int64_t const num_tokens,
...@@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
cudaStream_t const stream = 0) { cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group; int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config; cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks; config.gridDim = topk_with_k2_num_blocks;
config.blockDim = BLOCK_SIZE; config.blockDim = BLOCK_SIZE;
...@@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1; config.numAttrs = 1;
config.attrs = attrs; config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, auto const sf = static_cast<ScoringFunc>(scoring_func);
num_tokens, num_cases, n_group, num_experts / n_group, int64_t const num_experts_per_group = num_experts / n_group;
scoring_func); auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts_per_group);
};
switch (sf) {
case SCORING_NONE: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
launch_topk_with_k2(kernel_instance1);
break;
}
case SCORING_SIGMOID: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
launch_topk_with_k2(kernel_instance1);
break;
}
default:
// should be guarded by higher level checks.
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}
int64_t topk_with_k_group_num_blocks = int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes = size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK, warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk); topk);
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks; config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE; config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = dynamic_smem_in_bytes; config.dynamicSmemBytes = dynamic_smem_in_bytes;
...@@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, ...@@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1; config.numAttrs = 1;
config.attrs = attrs; config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, switch (sf) {
topk_values, topk_indices, bias, num_tokens, n_group, case SCORING_NONE: {
topk_group, topk, num_experts, num_experts / n_group, launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
renormalize, routed_scaling_factor, scoring_func); config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
case SCORING_SIGMOID: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
default:
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}
} }
#define INSTANTIATE_NOAUX_TC(T, IdxT) \ #define INSTANTIATE_NOAUX_TC(T, IdxT) \
......
This diff is collapsed.
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
// Maps a (row, col) coordinate of a row-major matrix that has `total_col`
// columns to its flat element offset.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  const int32_t row_base = row * total_col;
  return row_base + col;
}
}  // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
/**
 * Groups token indices by expert, separately for every LoRA slot, padding
 * each expert's segment up to a multiple of `block_size`.
 *
 * One thread block handles one LoRA slot: blockIdx.x selects an entry of
 * `lora_ids`. If that entry is -1, or `adapter_enabled[lora_id]` is 0, the
 * block returns immediately and writes nothing for that slot.
 *
 * Outputs, all indexed by `lora_id`:
 *  - sorted_token_ids[lora_id * max_num_tokens_padded + k]: flat indices into
 *    `topk_ids`, grouped by expert; unused slots hold `numel`.
 *  - expert_ids[lora_id * max_num_m_blocks + b]: expert owning block b,
 *    or -1 for unused blocks.
 *  - total_tokens_post_pad[lora_id]: padded token count over all experts.
 *
 * Dynamic shared memory layout (int32 units):
 *  - cumsum:      num_experts + 1 entries,
 *  - tokens_cnts: (blockDim.x + 1) x num_experts counters, where row t + 1
 *                 belongs to thread t and row 0 is the zero base of the
 *                 per-expert prefix sum.
 *
 * `max_loras` is not read by the kernel.
 * NOTE(review): values in `topk_ids` are used as expert indices without a
 * range check; they are assumed to lie in [0, num_experts) - confirm with the
 * caller.
 */
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
    int32_t* lora_ids) {
  // Each thread owns a contiguous shard of the flattened topk_ids.
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_idx = blockIdx.x;
  int lora_id = lora_ids[lora_idx];
  // The condition is uniform across the block, so returning before the
  // __syncthreads() calls below is safe.
  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
    return;
  }
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  // Clear this thread's row of per-expert counters.
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  // Count, per expert, the tokens of this shard that belong to this LoRA.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  // Afterwards row t holds the number of matching tokens seen by threads
  // 0..t-1 and row blockDim.x holds the per-expert total.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    // Tokens routed to another LoRA contribute nothing here. Skipping them
    // avoids an atomicAdd of 0 on a slot that may sit one past the expert's
    // padded segment (the next LoRA's row, or past the buffer for the last
    // row).
    const bool belongs_to_lora =
        static_cast<int>(token_lora_mapping[i / topk_num]) == lora_id;
    if (!belongs_to_lora) {
      continue;
    }

    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];

    // The slot was initialised to numel above; adding (i - numel) turns it
    // into i. The delta is computed in int32_t so no wrapped size_t value is
    // narrowed implicitly.
    atomicAdd(
        &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
        static_cast<int32_t>(i) - static_cast<int32_t>(numel));
    tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += 1;
  }
}
/**
 * Host entry point for moe_lora_align_sum_kernel.
 *
 * Launches one thread block per LoRA slot (`max_loras` blocks) on the current
 * CUDA stream of the device that holds `topk_ids`. The kernel fills
 * `sorted_token_ids`, `expert_ids` and `num_tokens_post_pad` in place.
 *
 * Raises (via TORCH_CHECK / AT_CUDA_CHECK) when:
 *  - block_size <= 0,
 *  - the required thread count max(num_experts, 128) exceeds 1024,
 *  - the required dynamic shared memory exceeds the device's opt-in limit,
 *  - a CUDA runtime query or attribute update fails.
 *
 * `topk_ids` is read as a 2-D tensor (size(1) is the top-k width); the
 * int32 tensors are accessed through data_ptr<int32_t>().
 */
void moe_lora_align_block_size(
    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
    int64_t num_experts, int64_t block_size, int64_t max_loras,
    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
    torch::Tensor lora_ids) {
  const int topk_num = topk_ids.size(1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");

  // Make the tensors' device current so the stream lookup and the launch
  // below target the device that actually owns the data.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));

  int device_max_shared_mem = 0;
  auto dev = topk_ids.get_device();
  // Surface a failed query instead of comparing against an undefined limit.
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The kernel needs at least one thread per expert.
  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must not exceed 1024, "
              "and fallback is not implemented yet.");
  // (num_thread + 1) x num_experts counters plus the (num_experts + 1)-entry
  // cumulative sum, all int32.
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  TORCH_CHECK(shared_mem <= device_max_shared_mem,
              "Shared memory usage exceeds device limit, and global memory "
              "fallback is not implemented yet.");

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>(),
            adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
      });
}
\ No newline at end of file
...@@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output); ...@@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad,
std::optional<torch::Tensor> maybe_expert_map);
void batched_moe_align_block_size(int64_t max_tokens_per_batch, void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size, int64_t block_size,
...@@ -26,7 +27,7 @@ void moe_lora_align_block_size( ...@@ -26,7 +27,7 @@ void moe_lora_align_block_size(
int64_t max_num_tokens_padded, int64_t max_num_m_blocks, int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids); torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales, torch::Tensor b_qweight, torch::Tensor b_scales,
......
...@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_align_block_size(Tensor topk_ids, int num_experts," "moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids," " int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids," " Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad,"
" Tensor? maybe_expert_map) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such // Aligning the number of tokens to be processed by each expert such
...@@ -46,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -46,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor !experts_ids," " Tensor !experts_ids,"
" Tensor !num_tokens_post_pad," " Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled," " Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () "); " Tensor !lora_ids,"
" Tensor? maybe_expert_map) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
......
...@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits, ...@@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask, const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties); const torch::Tensor& repetition_penalties);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowEnds, torch::Tensor& indices, const torch::Tensor& rowStarts,
int64_t numRows, int64_t stride0, int64_t stride1); const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seq_lens, torch::Tensor& indices, const torch::Tensor& seqLens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1); int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale, // torch::Tensor& weight, torch::Tensor& scale,
...@@ -128,6 +131,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out, ...@@ -128,6 +131,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std::optional<torch::Tensor> scale_ub, std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual); std::optional<torch::Tensor> residual);
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales, double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
...@@ -254,7 +264,8 @@ void get_cutlass_moe_mm_data( ...@@ -254,7 +264,8 @@ void get_cutlass_moe_mm_data(
void get_cutlass_moe_mm_problem_sizes( void get_cutlass_moe_mm_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets); const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes1,
...@@ -301,6 +312,14 @@ void per_token_group_quant_int8(const torch::Tensor& input, ...@@ -301,6 +312,14 @@ void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size, torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max); double eps, double int8_min, double int8_max);
// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s_packed,
int64_t group_size, double eps,
double min_8bit, double max_8bit);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
// ElementB is int32 (packed int4)
// ElementGroupScale is cutlass::Array<cutlass::float_e4m3_t, 8> (packed fp8)
/**
 * Fills the per-expert pointer tables used by the w4a8 grouped GEMM.
 *
 * Each thread handles the expert whose index equals its threadIdx.x: it reads
 * that expert's entry from `expert_offsets` and writes one pointer into every
 * `*_offsets` table, computed as an element offset from the matching
 * `*_base_as_int` pointer.
 *
 * @param expert_offsets  per-expert offset, multiplied by `k` (A), by `n`
 *                        (output) and used directly (A scales)
 * @param n, k            sizes used to scale the per-expert offsets
 * @param scale_k         group-scale count per expert along k; the group-scale
 *                        pointer advances by `scale_k * n` elements per expert
 */
template <typename ElementA, typename ElementB, typename ElementC,
          typename ElementAccumulator, typename ElementGroupScale>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets,
    ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int,
    ElementB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int,
    ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k,
    int64_t scale_k) {
  // Eight int4 values are packed into every int32 element of B.
  constexpr int kInt4PerWord = 8;

  const int slot = threadIdx.x;
  const int64_t first_row = expert_offsets[slot];

  // w4a8 specific: packed weights and their group scales, indexed by expert.
  ElementB* const b_start = b_base_as_int + (slot * k * n / kInt4PerWord);
  ElementGroupScale* const b_group_scales_start =
      b_group_scales_base_as_int + (slot * scale_k * n);

  // Same layout as the w8a8 path: activations, output and A scales are
  // addressed through the expert offset, B scales through the expert index.
  ElementA* const a_start = a_base_as_int + first_row * k;
  ElementC* const out_start = out_base_as_int + first_row * n;
  ElementAccumulator* const a_scales_start = a_scales_base_as_int + first_row;
  ElementAccumulator* const b_scales_start = b_scales_base_as_int + (n * slot);

  a_offsets[slot] = a_start;
  out_offsets[slot] = out_start;
  a_scales_offsets[slot] = a_scales_start;
  b_scales_offsets[slot] = b_scales_start;
  b_offsets[slot] = b_start;
  b_group_scales_offsets[slot] = b_group_scales_start;
}
// Expands to one `else if` branch of an output-dtype dispatch chain, so it must
// follow an `if`/`else if` block. When out_tensors.dtype() equals
// TENSOR_C_TYPE, it launches get_group_gemm_starts with C_TYPE as the output
// element type, using a single block of `num_experts` threads on `stream`.
// The expansion site must provide these names: expert_offsets, a_ptrs, b_ptrs,
// out_ptrs, a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs, a_tensors,
// b_tensors, out_tensors, a_scales, b_scales, b_group_scales, num_experts,
// stream, n, k and scale_k.
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int64_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<int32_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>**>( \
b_group_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<int32_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
static_cast<cutlass::Array<cutlass::float_e4m3_t, 8>*>( \
b_group_scales.data_ptr()), \
n, k, scale_k); \
}
namespace {

/**
 * Validates the w4a8 grouped-GEMM operands and launches
 * get_group_gemm_starts to fill the per-expert pointer tables.
 *
 * The kernel runs as a single thread block with one thread per entry of
 * `expert_offsets`, on the current CUDA stream of `a_tensors`' device.
 *
 * @param expert_offsets  int64 tensor; its size(0) is the number of experts
 * @param a_ptrs .. b_group_scales_ptrs  tensors whose storage receives the
 *                        per-expert pointers written by the kernel
 * @param a_tensors       float8_e4m3fn activations; size(1) is taken as k
 * @param b_tensors       int32 weights (eight int4 values per element)
 * @param out_tensors     bfloat16 output; size(1) is taken as n
 * @param a_scales, b_scales  float32 scales
 * @param b_group_scales  float8_e4m3fn group scales
 * @param b_group_size    group size along k; must be positive
 *
 * Fails a TORCH_CHECK on a dtype mismatch, a non-positive `b_group_size`, or
 * more experts than fit in one thread block. Returns without launching when
 * there are no experts.
 */
void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
    torch::Tensor const& a_scales, torch::Tensor const& b_scales,
    torch::Tensor const& b_group_scales, const int64_t b_group_size) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kInt32);  // int4 8x packed into int32
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_group_scales.dtype() ==
              torch::kFloat8_e4m3fn);  // the underlying torch type is e4m3
  // Only bf16 is supported for now, so the kFloat16 dispatch branch below is
  // currently unreachable.
  TORCH_CHECK(out_tensors.dtype() == torch::kBFloat16);
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
  // b_group_size is a divisor in the scale_k computation below.
  TORCH_CHECK(b_group_size > 0, "b_group_size must be positive, got ",
              b_group_size);

  // The kernel is launched as <<<1, num_experts>>>, so the expert count is the
  // thread-block size. A block size of zero or above the CUDA limit is an
  // invalid launch configuration that would otherwise fail without any check.
  constexpr int64_t kMaxThreadsPerBlock = 1024;
  const int64_t expert_count = expert_offsets.size(0);
  if (expert_count == 0) {
    return;  // nothing to compute
  }
  TORCH_CHECK(expert_count <= kMaxThreadsPerBlock,
              "get_group_gemm_starts uses one thread per expert in a single "
              "block; got ",
              expert_count, " experts (max ", kMaxThreadsPerBlock, ")");
  int num_experts = static_cast<int>(expert_count);

  // logical k, n
  int64_t n = out_tensors.size(1);
  int64_t k = a_tensors.size(1);
  int64_t scale_k = cutlass::ceil_div(k, b_group_size);

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  // Dispatch on the output dtype; each macro expands to an `else if` branch.
  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
\ No newline at end of file
This diff is collapsed.
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <torch/all.h> #include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h" #include "core/registration.h"
...@@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { ...@@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales; return packed_scales;
} }
/*
GPU-accelerated implementation of cutlass::unified_encode_int4b.
Constructs a lookup table in constant memory to map 8 bits
(two 4-bit values) at a time. Assumes memory is contiguous
and pointers are 16-byte aligned.
*/
__constant__ uint8_t kNibbleLUT[256];
// Remaps every byte of `in` through kNibbleLUT and stores the result in `out`,
// working on 16-byte (uint4) chunks. Only the nbytes / 16 whole chunks are
// processed; any remaining nbytes % 16 bytes are not touched by this kernel.
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
size_t nbytes) {
constexpr size_t V = sizeof(uint4); // 16 bytes
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
const size_t nvec = nbytes / V;
// 1-D grid-stride loop over 16-byte chunks
for (size_t vec = tid; vec < nvec; vec += nthreads) {
// Load one chunk, remap its 16 bytes in place, then store it back.
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
reinterpret_cast<uint4*>(out)[vec] = v;
}
}
// Builds the 256-entry byte lookup table on the host and copies it into the
// kNibbleLUT device symbol. Each entry applies map_nib independently to the
// high and low nibble of its index. Returns true when the copy succeeds.
static bool upload_lut() {
std::array<uint8_t, 256> lut{};
auto map_nib = [](uint8_t v) -> uint8_t {
// 1..7 -> (8 - v); keep 0 and 8..15
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
};
for (int b = 0; b < 256; ++b) {
uint8_t lo = b & 0xF;
uint8_t hi = (b >> 4) & 0xF;
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
}
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
/*offset=*/0, cudaMemcpyHostToDevice);
return (e == cudaSuccess);
}
// Host entry point: uploads the lookup table, then launches
// unified_encode_int4b_device over num_int4_elems / 2 bytes (two int4 values
// per byte). Returns false if the table upload fails or if cudaGetLastError
// reports an error after the launch; there is no synchronization here, so
// errors raised while the kernel runs are not observed.
// NOTE(review): the kernel only converts whole 16-byte chunks, so when nbytes
// is not a multiple of 16 the trailing bytes of `out` are never written,
// despite the "cover the tail" comment on the grid clamp below.
// NOTE(review): the launch passes no stream argument — confirm this ordering
// is intended relative to the caller's current stream.
static bool unified_encode_int4b(cutlass::int4b_t const* in,
cutlass::int4b_t* out, size_t num_int4_elems) {
// Build/upload LUT
if (!upload_lut()) return false;
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
"int4 storage must be 1 byte");
const size_t nbytes = num_int4_elems >> 1;
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
// kernel launch params
constexpr int block = 256;
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
int grid = int((nvec + block - 1) / block);
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
cudaError_t err = cudaGetLastError();
return (err == cudaSuccess);
}
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32); TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2); TORCH_CHECK(B.dim() == 2);
...@@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { ...@@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
LayoutB_Reordered layout_B_reordered = LayoutB_Reordered layout_B_reordered =
cute::tile_to_shape(LayoutAtomQuant{}, shape_B); cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
bool ok = bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed"); TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
......
#include "w4a8_utils.cuh"
#include <array>
#include <cuda_runtime.h>
#include <cstdio>
namespace vllm::cutlass_w4a8_utils {
/*
GPU-accelerated implementation of cutlass::unified_encode_int4b.
Constructs a lookup table in constant memory to map 8 bits
(two 4-bit values) at a time. Assumes memory is contiguous
and pointers are 16-byte aligned.
*/
__constant__ uint8_t kNibbleLUT[256];
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
size_t nbytes) {
constexpr size_t V = sizeof(uint4); // 16 bytes
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
const size_t nvec = nbytes / V;
// 1-D grid-stride loop over 16-byte chunks
for (size_t vec = tid; vec < nvec; vec += nthreads) {
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
#pragma unroll
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
reinterpret_cast<uint4*>(out)[vec] = v;
}
}
static bool upload_lut() {
std::array<uint8_t, 256> lut{};
auto map_nib = [](uint8_t v) -> uint8_t {
// 1..7 -> (8 - v); keep 0 and 8..15
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
};
for (int b = 0; b < 256; ++b) {
uint8_t lo = b & 0xF;
uint8_t hi = (b >> 4) & 0xF;
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
}
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
/*offset=*/0, cudaMemcpyHostToDevice);
return (e == cudaSuccess);
}
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
size_t num_int4_elems) {
// Build/upload LUT
if (!upload_lut()) return false;
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
"int4 storage must be 1 byte");
const size_t nbytes = num_int4_elems >> 1;
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
// kernel launch params
constexpr int block = 256;
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
int grid = int((nvec + block - 1) / block);
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
// launch errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("unified_encode_int4b_device launch error: %s (%d)\n",
cudaGetErrorString(err), err);
return false;
}
// runtime errors
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("unified_encode_int4b_device runtime error: %s (%d)\n",
cudaGetErrorString(err), err);
return false;
}
return true;
}
} // namespace vllm::cutlass_w4a8_utils
\ No newline at end of file
#pragma once
#include <cstddef>
#include "cutlass/numeric_types.h"
namespace vllm::cutlass_w4a8_utils {
// Encodes packed int4 data from `in` into `out` on the GPU.
// `num_int4_elems` is the number of 4-bit elements (two per byte).
// Returns true on success and false when a CUDA call reports an error.
bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out,
size_t num_int4_elems);
} // namespace vllm::cutlass_w4a8_utils
\ No newline at end of file
...@@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, ...@@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
void get_cutlass_moe_mm_problem_sizes_caller( void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
// Swap-AB should be disabled for FP4 path // Swap-AB should be disabled for FP4 path
bool may_swap_ab = (!blockscale_offsets.has_value()) && bool may_swap_ab =
(topk_ids.numel() <= SWAP_AB_THRESHOLD); force_swap_ab.value_or((!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD));
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream, atomic_buffer, num_experts, n, k, stream,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment