Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
...@@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output, ...@@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output,
} }
} }
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale,
const int32_t* azp, const int32_t* azp_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
cvt_vec_t token_scale_vec(a_scale[i]);
cvt_vec_t token_zp_scale_vec;
if constexpr (AZP) {
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
if constexpr (!PerChannel) {
zp_scale_val *= *b_scale;
}
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
}
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
#elif defined(__powerpc64__)
// Quantizes `input` to int8 with a single, caller-provided scale (and
// optional zero point):
//
//   out[i][j] = clamp(in[i][j] * (1 / *scale) + (AZP ? *azp : 0),
//                     INT8_MIN, INT8_MAX)
//
// The clamped fp32 vector is then converted by vec_op::INT8Vec16.
//
// Template parameters:
//   AZP - add the zero point *azp before clamping (azp is not read when
//         false).
//
// Parameters:
//   input       - [num_tokens, hidden_size] values, row-major.
//   output      - [num_tokens, hidden_size] int8 destination, row-major.
//   scale       - pointer to the single quantization scale.
//   azp         - pointer to the single zero point (AZP only).
//   num_tokens  - number of rows.
//   hidden_size - number of columns.
//
// NOTE(review): the tail chunk constructs a full-width vector from
// row_in + j; only the store is limited to hidden_size - j elements.
// Whether that reads past the row depends on load_vec_t - confirm.
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t inv_scale(1.0 / *scale);
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

  cvt_vec_t zp_vec;
  if constexpr (AZP) {
    zp_vec = cvt_vec_t(static_cast<float>(*azp));
  }

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    // Compute the row offset in 64 bits: num_tokens * hidden_size may exceed
    // INT_MAX, and the previous `i * hidden_size` was evaluated in int.
    const int64_t row_offset = static_cast<int64_t>(i) * hidden_size;
    const scalar_t* row_in = input + row_offset;
    int8_t* row_out = output + row_offset;

    // Quantizes one vector-wide chunk starting at column j. Shared by the
    // main loop and the tail.
    auto quantize_chunk = [&](int j) {
      load_vec_t elems(row_in + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = elems_fp32 * inv_scale;
      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + zp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      return vec_op::INT8Vec16(elems_fp32);
    };

    // Full-width chunks. The strict `<` leaves the last chunk (full or
    // partial) to the counted store below.
    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      quantize_chunk(j).save(row_out + j);
    }
    quantize_chunk(j).save(row_out + j, hidden_size - j);
  }
}
// Quantizes every token (row) of `input` to int8 using parameters derived
// from that row's own value range, and reports those parameters.
//
//   AZP == false (symmetric):
//     scale[i] = max_j |in[i][j]| / 127
//     out[i][j] = clamp(in[i][j] / scale[i], INT8_MIN, INT8_MAX)
//   AZP == true (asymmetric):
//     scale[i] = (max_j in[i][j] - min_j in[i][j]) / 255
//     azp[i]   = nearbyint(-128 - min_j in[i][j] / scale[i])
//     out[i][j] = clamp(in[i][j] / scale[i] + azp[i], INT8_MIN, INT8_MAX)
//
// Parameters:
//   input       - [num_tokens, hidden_size] values, row-major.
//   output      - [num_tokens, hidden_size] int8 destination, row-major.
//   scale       - receives one scale per token.
//   azp         - receives one zero point per token (written only if AZP).
//   num_tokens  - number of rows.
//   hidden_size - number of columns.
//
// NOTE(review): a row whose range is zero yields scale_val == 0 and therefore
// an infinite inverse scale; this function does not special-case it.
// NOTE(review): tail chunks construct a full-width vector from row_in + j;
// only the reduction/store is limited to hidden_size - j elements. Whether
// that reads past the row depends on load_vec_t - confirm.
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    // Compute the row offset in 64 bits: num_tokens * hidden_size may exceed
    // INT_MAX, and the previous `i * hidden_size` was evaluated in int.
    const int64_t row_offset = static_cast<int64_t>(i) * hidden_size;
    const scalar_t* row_in = input + row_offset;
    int8_t* row_out = output + row_offset;

    // ---- Pass 1: find the row's value range. ----
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(row_in + j);
        cvt_vec_t elems_fp32(elems);
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      }

      // Last chunk: full-width when the row length is a multiple of the
      // vector width, otherwise only the first hidden_size - j lanes count.
      load_vec_t elems(row_in + j);
      cvt_vec_t elems_fp32(elems);
      if (j + vec_elem_num == hidden_size) {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      } else {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
      }
    }

    // ---- Derive and publish the per-token quantization parameters. ----
    float scale_val;
    // Must be initialised: azp_vec below is built from it even when AZP is
    // false, which previously read an indeterminate value.
    float azp_val = 0.0f;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

    const cvt_vec_t inv_scale(1.0 / scale_val);
    const cvt_vec_t azp_vec(azp_val);

    // ---- Pass 2: quantize the row. ----
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(row_in + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);
        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(row_out + j);
      }

      load_vec_t elems(row_in + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);
      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(row_out + j, hidden_size - j);
    }
  }
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t a_scale_vec(a_scale);
cvt_vec_t b_scale_vec(*b_scale);
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t> template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output, void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale, const float* a_scale, const float* b_scale,
...@@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, ...@@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp, const float* scale, const int32_t* azp,
const int num_tokens, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") TORCH_CHECK(
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
} }
template <typename scalar_t> template <typename scalar_t>
...@@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, ...@@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp, float* scale, int32_t* azp,
const int num_tokens, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") TORCH_CHECK(
false,
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
} }
template <bool PerChannel, typename scalar_t> template <bool PerChannel, typename scalar_t>
...@@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output, ...@@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale, const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens, const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
} }
template <typename scalar_t> template <typename scalar_t>
...@@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, ...@@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
const int32_t* azp, const int32_t* azp_with_adj, const int32_t* azp, const int32_t* azp_with_adj,
const scalar_t* bias, const int num_tokens, const scalar_t* bias, const int num_tokens,
const int hidden_size) { const int hidden_size) {
TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") TORCH_CHECK(false,
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
} }
#endif #endif
} // namespace } // namespace
...@@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant( ...@@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant(
} }
}); });
} }
#if defined(__powerpc64__)
// Symmetric W8A8 GEMM for ppc64le: C = s_a * (s_b * (A @ B)) [+ bias].
//
// The int8 product is first computed into an fp32 temporary with the weight
// scales applied, then dynamic_quant_epilogue applies the activation scales
// (and bias, if given) while converting to c's floating-point dtype.
//
// Parameters:
//   c        - [M, OC] output, row-major, floating point.
//   a        - [M, IC] int8 activations, row-major.
//   b        - [IC, OC] int8 weights, column-major.
//   a_scales - activation scales: 1 element (per-tensor) or M (per-token).
//   b_scales - weight scales: 1 element or OC elements.
//   bias     - optional [OC] bias with the same dtype as c.
//
// Fails a TORCH_CHECK when shapes, strides or dtypes do not match the
// requirements verified below.
void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major
                            const torch::Tensor& a,  // [M, IC], row-major
                            const torch::Tensor& b,  // [IC, OC], column-major
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_ppc64le only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // Scales are either a single value or one value per row (A) / column (B).
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  // dynamic_quant_epilogue reads a_scale[i] for every row i. A per-tensor
  // scale (numel == 1) therefore has to be materialised as one value per row;
  // passing it through unchanged made the epilogue read past the end of the
  // scale buffer whenever M > 1. After the check above, a mismatch with M can
  // only mean numel == 1.
  torch::Tensor a_scales_per_token = a_scales;
  if (a_scales.numel() != a.size(0)) {
    a_scales_per_token =
        a_scales.reshape({1}).expand({a.size(0)}).contiguous();
  }

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    // Compute C_inter=s_b * (A@B)
    DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
    if (bias.has_value()) {
      // Compute C=s_a * C_inter + bias
      dynamic_quant_epilogue<false, true, true>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales_per_token.data_ptr<float>(), nullptr, nullptr, nullptr,
          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
    } else {
      // Compute C=s_a * C_inter
      dynamic_quant_epilogue<false, true, false, scalar_t>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales_per_token.data_ptr<float>(), nullptr, nullptr, nullptr,
          nullptr, c.size(0), c.size(1));
    }
  });
}
#endif
...@@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, ...@@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp, const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& bias); const std::optional<torch::Tensor>& bias);
#if defined(__powerpc64__)
// Int8 scaled matrix multiply for ppc64le builds; only declared when
// __powerpc64__ is defined. Writes the result into c. bias is optional.
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const std::optional<torch::Tensor>& bias);
#endif
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale, torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens); torch::Tensor& block_tables, torch::Tensor& seq_lens);
...@@ -117,7 +125,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -117,7 +125,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def( ops.def(
"rotary_embedding(Tensor positions, Tensor! query," "rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size," " Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
...@@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b_scales, Tensor azp_adj," " Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()"); " Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
// W8A8 GEMM, supporting symmetric quantization.
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif #endif
// SHM CCL // SHM CCL
......
...@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel { ...@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
#endif #endif
} }
}; };
// Wrapper that derives from Kernel and forwards operator() to it only when
// the device code is being compiled for __CUDA_ARCH__ == 1000. For every
// other target the call operator has an empty body.
template <typename Kernel>
struct enable_sm100_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000)
    // Perfect-forward the arguments to the wrapped kernel's call operator.
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
...@@ -65,5 +65,19 @@ ...@@ -65,5 +65,19 @@
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
// Dispatch cases for the integral scalar types Byte, Char, Short, Int and
// Long plus the unsigned types UInt16, UInt32 and UInt64. Each case receives
// the same __VA_ARGS__.
#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
// Wraps AT_DISPATCH_SWITCH over TYPE (with NAME passed through) using the
// case list from VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES.
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
...@@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] ...@@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// Tries the vllm::kU4 kernel variants listed below, in order, via
// AWQ_CALL_IF_MOE. Returns true when one of the macro arms was taken and
// false when control reaches the final else.
//
// The bool return lets callers chain several of these functions as a sequence
// of if / else-if tests.
//
// NOTE(review): AWQ_CALL_IF_MOE is defined elsewhere; presumably each use
// expands to an `else if` arm that reads the parameters (and the local
// has_zp) by name and launches a kernel - confirm against its definition.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
// Not referenced directly in this body; presumably read by the macro arms.
bool has_zp = true;
// Empty seed branch so that every macro use below can continue the chain.
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
// None of the arms above was taken.
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// Declaration of the vllm::kU4 kernel selector. It returns bool so that
// callers can chain the per-type selectors as a sequence of if / else-if
// tests.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
#include "marlin_moe_kernel_ku4b8.h"
namespace marlin_moe {
// Tries the vllm::kU4B8 kernel variants listed below, in order, via
// GPTQ_CALL_IF_MOE. Returns true when one of the macro arms was taken and
// false when control reaches the final else.
//
// The bool return lets callers chain several of these functions as a sequence
// of if / else-if tests.
//
// NOTE(review): GPTQ_CALL_IF_MOE is defined elsewhere; presumably each use
// expands to an `else if` arm that reads the parameters (and the local
// has_zp) by name and launches a kernel - confirm against its definition.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
// Not referenced directly in this body; presumably read by the macro arms.
bool has_zp = false;
// Empty seed branch so that every macro use below can continue the chain.
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
else {
// None of the arms above was taken.
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// Declaration of the vllm::kU4B8 kernel selector. It returns bool so that
// callers can chain the per-type selectors as a sequence of if / else-if
// tests.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
#include "marlin_moe_kernel_ku8b128.h"
namespace marlin_moe {
// Tries the vllm::kU8B128 kernel variants listed below, in order, via
// GPTQ_CALL_IF_MOE. Returns true when one of the macro arms was taken and
// false when control reaches the final else.
//
// The bool return lets callers chain several of these functions as a sequence
// of if / else-if tests.
//
// NOTE(review): GPTQ_CALL_IF_MOE is defined elsewhere; presumably each use
// expands to an `else if` arm that reads the parameters (and the local
// has_zp) by name and launches a kernel - confirm against its definition.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
// Not referenced directly in this body; presumably read by the macro arms.
bool has_zp = false;
// Empty seed branch so that every macro use below can continue the chain.
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
else {
// None of the arms above was taken.
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// Declaration of the vllm::kU8B128 kernel selector. Its signature matches
// the other call_marlin_moe_kernel_* selectors; the bool result reports
// whether a kernel variant was selected.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
}
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
// Shorthand for std::to_string: returns the textual form of a numeric value.
template <typename T>
inline std::string str(T x) {
  std::string text = std::to_string(x);
  return text;
}
namespace marlin_moe {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Column gather on a half-precision matrix "a" of size [M,K]:
//   out[row][k] = a[row][perm[k]]   for every k in [0, size_k).
//
// Each block processes up to block_rows consecutive rows starting at
// block_rows * blockIdx.x; rows at or beyond size_m are skipped. Within a row
// the block's threads cover k = threadIdx.x, threadIdx.x + blockDim.x, ...
//
// The buffers are passed as int4 pointers and reinterpreted as half.
// NOTE(review): the row pitch is size_k * sizeof(half) / 16 int4 elements,
// so this presumably expects size_k * sizeof(half) to be a multiple of 16 -
// confirm against the caller.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Half-open row range [first_row, last_row) owned by this block, clamped to
  // the matrix height.
  int first_row = block_rows * blockIdx.x;
  int last_row = first_row + block_rows;
  if (last_row > size_m) {
    last_row = size_m;
  }

  // Row pitch measured in int4 elements.
  int row_stride = size_k * sizeof(half) / 16;

  // Number of passes in which every thread handles one column, plus the
  // number of leftover columns handled by the lowest-numbered threads.
  int full_passes = size_k / blockDim.x;
  int leftover = size_k % blockDim.x;

  for (int row = first_row; row < last_row; row++) {
    int offset = row * row_stride;
    half const* src = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* dst = reinterpret_cast<half*>(out_int4_ptr + offset);

    int k = threadIdx.x;
    for (int pass = 0; pass < full_passes; pass++) {
      dst[k] = src[perm_int_ptr[k]];
      k += blockDim.x;
    }

    if (threadIdx.x < leftover) {
      dst[k] = src[perm_int_ptr[k]];
    }
  }
}
// Builds the expert_offsets table from the flat list of selected expert ids.
//
// The block size is taken as the number of experts and thread t acts for
// expert t; blockIdx is not consulted.
//   1. Thread t counts how many of the topk_length entries of topk_ids equal
//      t and stores the count in expert_offsets[t + 1].
//   2. After a block barrier, thread 0 sets expert_offsets[0] = 0 and turns
//      the counts into a running total in which every count is first passed
//      through ceildiv(count, block_size) * block_size.
//
// expert_offsets must hold blockDim.x + 1 entries.
// NOTE(review): ceildiv is defined elsewhere; presumably ceiling division, so
// each expert's span is padded to a multiple of block_size - confirm.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  int const num_experts = blockDim.x;
  int const my_expert = threadIdx.x;

  // Phase 1: per-expert occurrence count.
  int count = 0;
  for (int idx = 0; idx < topk_length; ++idx) {
    if (topk_ids[idx] == my_expert) {
      ++count;
    }
  }
  expert_offsets[my_expert + 1] = count;

  // All counts must be visible before thread 0 reads them.
  __syncthreads();

  // Phase 2: serial prefix sum over the padded counts.
  if (my_expert == 0) {
    expert_offsets[0] = 0;
    int running_total = 0;
    for (int e = 1; e <= num_experts; ++e) {
      running_total += ceildiv(expert_offsets[e], block_size) * block_size;
      expert_offsets[e] = running_total;
    }
  }

  __syncthreads();
}
#else
// Fallback compiled when the >= 800 architecture test above fails: the
// kernel has no implementation here and trips an assertion if executed.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
// Fallback compiled when the >= 800 architecture test above fails: the
// kernel has no implementation here and trips an assertion if executed.
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
}
#endif
// One candidate thread-block configuration. A value of -1 in any field marks
// the configuration as invalid (see is_valid_config).
typedef struct {
int thread_k; // tile extent along K; is_valid_config requires prob_k % thread_k == 0
int thread_n; // tile extent along N; is_valid_config requires prob_n % thread_n == 0
int num_threads; // presumably threads per block; is_valid_config requires >= 128
} thread_config_t;
// Result of configuration selection: a thread-block configuration together
// with the max_m_blocks value it was validated for. determine_thread_config
// returns max_m_blocks == 0 when no configuration fits.
typedef struct {
int max_m_blocks; // M blocks processed per invocation
thread_config_t tb_cfg; // chosen thread-block configuration
} exec_config_t;
// Candidate configurations tried by determine_thread_config when
// prob_m <= 16. The first entry that passes is_valid_config wins.
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
};
// Candidate configurations tried by determine_thread_config when
// prob_m > 16. The first entry that passes is_valid_config wins.
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
};
// Returns the size of the scales cache required by th_config.
//
// The number of scale groups covered by one thread-block tile along K is
//   1                              when group_size == -1,
//   ceildiv(thread_k, 32)          when group_size == 0 (worst case: groups
//                                  of 32),
//   ceildiv(thread_k, group_size)  otherwise.
// When has_act_order is set and is_k_full is not, scales are loaded in
// chunks of at least 32 groups; otherwise the per-tile amount is multiplied
// by STAGES.
//
// prob_m, prob_n, prob_k and num_bits are accepted but not used.
// NOTE(review): the result is presumably in bytes - confirm with the caller.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  int const tile_n = th_config.thread_n;
  int const tile_k = th_config.thread_k;

  // Max scale groups per thread-block.
  int groups_per_tile;
  switch (group_size) {
    case -1:
      groups_per_tile = 1;
      break;
    case 0:
      groups_per_tile = ceildiv(tile_k, 32);  // Worst case is 32 group size
      break;
    default:
      groups_per_tile = ceildiv(tile_k, group_size);
      break;
  }

  bool const chunked_scales = has_act_order && !is_k_full;
  if (!chunked_scales) {
    int const scales_per_stage = groups_per_tile * tile_n * 2;
    return scales_per_stage * STAGES;
  }

  // Chunk size is 2x pipeline over dim K, and at least 32 scale groups are
  // loaded.
  int chunk_groups = groups_per_tile * STAGES * 2;
  chunk_groups = max(chunk_groups, 32);
  return chunk_groups * tile_n * 4;
}
// Reports whether the A/B tile pipeline for th_config fits into the shared
// memory left after reserving scales_cache_size.
//
// max_m_blocks is first lowered until it no longer exceeds
// ceildiv(prob_m, 16). A TORCH_CHECK failure is raised if that lowers it to
// zero, and another if scales_cache_size is not below half of
// max_shared_mem. The function returns true when
//   (a_size + b_size) * STAGES < 0.95 * (max_shared_mem - scales_cache_size).
//
// prob_n and prob_k are accepted but not used.
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int const tile_k = th_config.thread_k;
  int const tile_n = th_config.thread_n;

  // B tile size: num_bits-wide values packed into 32-bit words.
  int const pack_factor = 32 / num_bits;
  int const b_size = (tile_k * tile_n / pack_factor) * 4;

  // A tile size: 16 rows per M block, limited by the M blocks that exist.
  int const m_blocks = ceildiv(prob_m, 16);
  while (m_blocks < max_m_blocks) {
    --max_m_blocks;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }
  int const tb_max_m = 16 * max_m_blocks;
  int const a_size = (tb_max_m * tile_k) * 2;

  float const pipe_size = (a_size + b_size) * STAGES;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
// Returns true when th_config can be used for the given problem:
//   - no field is the -1 "unset" marker,
//   - thread_k divides prob_k and thread_n divides prob_n,
//   - thread_k is 64 or 128,
//   - thread_n / thread_k are not below min_thread_n / min_thread_k,
//   - num_threads is at least 128 (= 4 warps),
//   - the resulting pipeline passes is_valid_cache_size.
// The checks run in this order and the first failure returns false.
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  int const tile_k = th_config.thread_k;
  int const tile_n = th_config.thread_n;
  int const threads = th_config.num_threads;

  // Sanity: reject the "no config" marker.
  bool const is_unset = tile_k == -1 || tile_n == -1 || threads == -1;
  if (is_unset) {
    return false;
  }

  // K/N must be divisible by thread K/N.
  bool const tiles_evenly = (prob_k % tile_k == 0) && (prob_n % tile_n == 0);
  if (!tiles_evenly) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128).
  if (!(tile_k == 128 || tile_k == 64)) {
    return false;
  }

  // Minimum thread K/N.
  if (tile_n < min_thread_n || tile_k < min_thread_k) {
    return false;
  }

  // At least 128 threads (= 4 warps).
  if (threads < 128) {
    return false;
  }

  // The pipeline, together with the scales cache, has to fit into cache.
  int const scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  return is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                             num_bits, scales_cache_size, max_shared_mem);
}
// Picks an execution configuration automatically.
//
// Starting with max_m_blocks = 4 and counting down to 1, walks the priority
// list for the batch size (small_batch_thread_configs when prob_m <= 16,
// large_batch_thread_configs otherwise) and returns the first entry accepted
// by is_valid_config. Lower max_m_blocks values process fewer M blocks per
// invocation to reduce cache usage.
//
// Returns {0, {-1, -1, -1}} when nothing fits.
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  bool const small_batch = prob_m <= 16;

  thread_config_t const* const candidates =
      small_batch ? small_batch_thread_configs : large_batch_thread_configs;
  size_t const num_candidates =
      small_batch ? sizeof(small_batch_thread_configs) /
                        sizeof(small_batch_thread_configs[0])
                  : sizeof(large_batch_thread_configs) /
                        sizeof(large_batch_thread_configs[0]);

  for (int max_m_blocks = 4; max_m_blocks > 0; --max_m_blocks) {
    for (size_t idx = 0; idx < num_candidates; ++idx) {
      thread_config_t const& candidate = candidates[idx];
      bool const usable = is_valid_config(
          candidate, max_m_blocks, prob_m, prob_n, prob_k, num_bits,
          group_size, has_act_order, is_k_full, max_shared_mem);
      if (usable) {
        return exec_config_t{max_m_blocks, candidate};
      }
    }
  }

  return exec_config_t{0, {-1, -1, -1}};
}
// Expands to one `else if (KERNEL_FUNCTION(...)) { }` arm of a dispatch chain.
// It must therefore follow an `if`/`else if` statement, and every identifier
// in the argument list (q_type, thread_n_blocks, ..., exec_cfg) is taken by
// name from the scope where the macro is used.
// NOTE(review): KERNEL_FUNCTION presumably returns true once it has handled
// (launched) the requested configuration, which ends the chain - confirm
// against the call_marlin_moe_kernel_* definitions.
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, void* zp,
const void* g_idx, const void* perm, void* a_tmp,
void* expert_offsets, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type,
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int num_experts, int topk,
int moe_block_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int num_bits = q_type.size_bits();
// Set thread config
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
} else {
// Auto config
exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
}
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int blocks = sms;
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
int tot_m = prob_m;
const int* topk_ids_ptr = (const int*)topk_ids;
int* expert_offsets_ptr = (int*)expert_offsets;
compute_expert_offsets<<<1, num_experts, 0, stream>>>(
topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
bool do_permute_a = has_act_order;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) {
has_act_order = false;
}
int pack_factor = 32 / q_type.size_bits();
for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
const int4* A_ptr = (const int4*)A;
int4* a_tmp_ptr = (int4*)a_tmp;
const int4* B_ptr =
(const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
if (do_permute_a) {
// Permute A columns
int topk_rows = replicate_input ? tot_m : tot_m * topk;
int block_rows = ceildiv(topk_rows, blocks);
permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
A_ptr = a_tmp_ptr;
}
int tot_m_blocks = ceildiv(tot_m, 16);
for (int m_block = 0; m_block < tot_m_blocks;
m_block += 4 * exec_cfg.max_m_blocks) {
if (false) {
}
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
}
}
}
} // namespace marlin_moe
// Torch entry point for the Marlin MoE GEMM.
//
// Validates the quantization type and the shapes of b_scales / b_zeros,
// derives num_groups and group_size from b_scales, allocates the output `c`
// of shape (size_m, topk, size_n) plus the scratch tensors a_tmp and
// expert_offsets on a's device, and forwards everything to
// marlin_moe::marlin_mm_moe on the current CUDA stream of that device.
//
// group_size encoding passed on to the kernel driver:
//   -1 : a single group (num_groups <= 1, no act-order)
//    0 : act-order without a full K
//   >0 : size_k / num_groups
//
// Returns the output tensor `c`. Raises via TORCH_CHECK on an unsupported
// b_q_type or inconsistent tensor shapes.
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    torch::Tensor& b_zeros, const torch::Tensor& g_idx,
    const torch::Tensor& perm, torch::Tensor& workspace,
    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);

  // Zero points are considered present when b_zeros has a non-empty dim 1.
  bool has_zp = b_zeros.size(1) != 0;
  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
  }

  int pack_factor = 32 / b_q_type.size_bits();

  int max_par = 4;

  int dev = a.get_device();

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
  torch::Tensor a_tmp =
      replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                      : torch::zeros({size_m, topk, size_k}, options_dtype);
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  // Act-order is considered enabled when g_idx has a non-empty dim 1.
  bool has_act_order = g_idx.size(1) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
              "if is_k_full is false, has_act_order must be true");

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }
  } else {
    if (num_groups > 1) {
      // The divisor being checked is num_groups (b_scales.size(1)), so the
      // message reports that value.
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    TORCH_CHECK(b_zeros.size(1) == num_groups,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                "b_zeros dim 2 = ", b_zeros.size(2),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
      num_experts, topk, moe_block_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
      replicate_input, apply_weights);
  return c;
}
// Registers marlin_gemm_moe as the CUDA-dispatch implementation of the
// "marlin_gemm_moe" operator in the TORCH_EXTENSION_NAME library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}
kernel_*.cu
\ No newline at end of file
...@@ -25,15 +25,16 @@ TEMPLATE = ("template __global__ void Marlin<" ...@@ -25,15 +25,16 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, " "{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, " "{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, " "{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );") "( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported, # int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size. # we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
...@@ -41,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] ...@@ -41,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# = 0 : act order case # = 0 : act order case
# = -1 : channelwise quantization # = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks # > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8] GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"] DTYPES = ["fp16", "bf16"]
...@@ -52,21 +53,35 @@ def remove_old_kernels(): ...@@ -52,21 +53,35 @@ def remove_old_kernels():
def generate_new_kernels(): def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = [] all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
has_act_order = group_blocks == 0 # act order case only support gptq-int4 and gptq-int8
if has_zp and has_act_order: if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue continue
if thread_configs[2] == 256: if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128: if m_blocks <= 1 and thread_configs[0] != 128:
continue continue
if m_blocks > 1 and thread_configs[0] != 64: if m_blocks > 1 and thread_configs[0] != 64:
continue continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16 k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16 n_blocks = thread_configs[1] // 16
threads = thread_configs[2] threads = thread_configs[2]
...@@ -82,8 +97,6 @@ def generate_new_kernels(): ...@@ -82,8 +97,6 @@ def generate_new_kernels():
thread_k_blocks=k_blocks, thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5, m_block_size_8=m_blocks == 0.5,
stages="pipe_stages", stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks, group_blocks=group_blocks,
is_zp_float=False, is_zp_float=False,
) )
......
...@@ -7,18 +7,19 @@ ...@@ -7,18 +7,19 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \ #define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \ const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ const int4 *__restrict__ scales_ptr, \
const int *__restrict__ g_idx, \ const uint16_t *__restrict__ scale2_ptr, \
const int32_t *__restrict__ sorted_token_ids_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ expert_ids_ptr, \ const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \ const int32_t *__restrict__ expert_ids_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \ const int32_t *__restrict__ num_tokens_past_padded_ptr, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ const float *__restrict__ topk_weights_ptr, int top_k, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
bool use_fp32_reduce int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
...@@ -33,11 +34,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16 ...@@ -33,11 +34,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const int group_blocks, // number of consecutive 16x16 blocks
const bool has_zp, // whether zero-points are enabled // with a separate quantization scale
const int group_blocks, // number of consecutive 16x16 blocks const bool is_zp_float // is zero point of float16 type?
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin(MARLIN_KERNEL_PARAMS); __global__ void Marlin(MARLIN_KERNEL_PARAMS);
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16 ...@@ -48,11 +49,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const int group_blocks, // number of consecutive 16x16 blocks
const bool has_zp, // whether zero-points are enabled // with a separate quantization scale
const int group_blocks, // number of consecutive 16x16 blocks const bool is_zp_float // is zero point of float16 type?
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin( __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
...@@ -77,8 +76,8 @@ __global__ void Marlin( ...@@ -77,8 +76,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce bool use_fp32_reduce, // whether to use fp32 global reduce
) {} int max_shared_mem) {}
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
...@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, ...@@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
} }
} }
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
int q, typename ScalarType<half>::FragB& frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
q >>= 4;
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
int q, typename ScalarType<half>::FragB& frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
template <typename scalar_t> template <typename scalar_t>
...@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16 ...@@ -429,11 +290,9 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1 // only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared const int stages, // number of stages for the async global->shared
// fetch pipeline // fetch pipeline
const bool has_act_order, // whether act_order is enabled const int group_blocks, // number of consecutive 16x16 blocks
const bool has_zp, // whether zero-points are enabled // with a separate quantization scale
const int group_blocks, // number of consecutive 16x16 blocks const bool is_zp_float // is zero point of float16 type?
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin( __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A, // fp16 input matrix of shape mxk
...@@ -442,9 +301,11 @@ __global__ void Marlin( ...@@ -442,9 +301,11 @@ __global__ void Marlin(
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn // (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// (k/groupsize)x(n/pack_factor) // only)
const int* __restrict__ g_idx, // int32 group indices of shape k const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
...@@ -458,8 +319,8 @@ __global__ void Marlin( ...@@ -458,8 +319,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce bool use_fp32_reduce, // whether to use fp32 global reduce
) { int max_shared_mem) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the // Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 * // same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
...@@ -481,13 +342,26 @@ __global__ void Marlin( ...@@ -481,13 +342,26 @@ __global__ void Marlin(
extern __shared__ int4 sh[]; extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
!is_int_type ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
scalar_t2 global_scale;
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits(); constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8); static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size = const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8; const int scales_expert_stride =
prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8);
const int zp_expert_stride = const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4); : prob_n * prob_k / group_size / (pack_factor * 4);
...@@ -534,13 +408,20 @@ __global__ void Marlin( ...@@ -534,13 +408,20 @@ __global__ void Marlin(
int64_t B_expert_off = 0; int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh; int4* sh_block_sorted_ids_int4 = sh;
int4* sh_rd_block_sorted_ids_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int4* sh_block_topk_weights_int4 =
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4* sh_new =
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
int32_t* sh_block_sorted_ids = int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4); reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 = int32_t* sh_rd_block_sorted_ids =
sh_block_sorted_ids_int4 + moe_block_size / 4; reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
scalar_t2* sh_block_topk_weights = scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4); reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0; int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0; int32_t locks_off = 0;
...@@ -584,12 +465,24 @@ __global__ void Marlin( ...@@ -584,12 +465,24 @@ __global__ void Marlin(
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>( sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] =
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) { if (mul_topk_weights) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] = int idx = tid4 * 4 + i;
Dtype::num2num2(Dtype::float2num( idx = idx < block_num_valid_tokens ? idx : 0;
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); if constexpr (w_type == vllm::kFE2M1f) {
sh_block_topk_weights[idx] = __hmul2(
global_scale, Dtype::num2num2(Dtype::float2num(
topk_weights_ptr[sh_block_sorted_ids[idx]])));
} else {
sh_block_topk_weights[idx] = Dtype::num2num2(
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
}
} }
} }
} }
...@@ -620,6 +513,11 @@ __global__ void Marlin( ...@@ -620,6 +513,11 @@ __global__ void Marlin(
expert_id = expert_ids_ptr[block_id]; expert_id = expert_ids_ptr[block_id];
} }
if constexpr (w_type == vllm::kFE2M1f) {
uint16_t val = scale2_ptr[expert_id];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) { if constexpr (has_zp) {
...@@ -733,7 +631,7 @@ __global__ void Marlin( ...@@ -733,7 +631,7 @@ __global__ void Marlin(
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
: 1; : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride; constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride; int s_gl_rd_delta = s_gl_stride;
...@@ -743,6 +641,7 @@ __global__ void Marlin( ...@@ -743,6 +641,7 @@ __global__ void Marlin(
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1; // constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups; // int act_s_col_stride = act_s_row_stride * num_groups;
constexpr int act_s_max_num_groups = 32;
int act_s_col_stride = 1; int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8; int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4; int tb_n_warps = thread_n_blocks / 4;
...@@ -758,9 +657,9 @@ __global__ void Marlin( ...@@ -758,9 +657,9 @@ __global__ void Marlin(
int zp_gl_rd_delta = zp_gl_stride; int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread. // Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
(threadIdx.x % a_gl_rd_delta_o); int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread. // Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o); (threadIdx.x % a_gl_rd_delta_o);
...@@ -774,8 +673,8 @@ __global__ void Marlin( ...@@ -774,8 +673,8 @@ __global__ void Marlin(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs; (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row; b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs; auto b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order // For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters; constexpr int k_iter_size = tb_k / b_sh_wr_iters;
...@@ -790,11 +689,12 @@ __global__ void Marlin( ...@@ -790,11 +689,12 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else { } else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x; s_sh_stride * slice_col + threadIdx.x;
} }
} }
int s_sh_wr = threadIdx.x; auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points // Zero-points
...@@ -807,17 +707,27 @@ __global__ void Marlin( ...@@ -807,17 +707,27 @@ __global__ void Marlin(
zp_sh_stride * slice_col + threadIdx.x; zp_sh_stride * slice_col + threadIdx.x;
} }
} }
int zp_sh_wr = threadIdx.x; auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as // We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in // we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case. // row-major in the latter case.
int s_sh_rd; int s_sh_rd;
if constexpr (group_blocks != -1) if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4; (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) s_sh_rd = s_sh_rd * 2 + warp_row % 2;
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8; (threadIdx.x % 32) / 8;
else else
...@@ -851,7 +761,7 @@ __global__ void Marlin( ...@@ -851,7 +761,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment? // each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) { auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o; int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
}; };
// Since the computation of this remapping is non-trivial and, due to our main // Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute // loop unrolls, all shared memory accesses are static, we simply precompute
...@@ -879,12 +789,28 @@ __global__ void Marlin( ...@@ -879,12 +789,28 @@ __global__ void Marlin(
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
int4* sh_a = sh_new; constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
int4* sh_b = sh_a + (stages * a_sh_stage); constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_g_idx = sh_b + (stages * b_sh_stage); int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_b; // shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
...@@ -905,15 +831,14 @@ __global__ void Marlin( ...@@ -905,15 +831,14 @@ __global__ void Marlin(
int sh_first_group_id = -1; int sh_first_group_id = -1;
int sh_num_groups = -1; int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) { int last_group_id) {
sh_first_group_id = first_group_id; sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1; sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) { if (sh_num_groups > act_s_max_num_groups) {
sh_num_groups = sh_max_num_groups; sh_num_groups = act_s_max_num_groups;
} }
if (sh_first_group_id + sh_num_groups > num_groups) { if (sh_first_group_id + sh_num_groups > num_groups) {
...@@ -940,27 +865,31 @@ __global__ void Marlin( ...@@ -940,27 +865,31 @@ __global__ void Marlin(
} }
} }
}; };
// Asynchronously fetch the next A, B and s tile from global to the next // Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location. // shared memory pipeline location.
int a_remaining_load_count_in_slice = stages; bool should_load_a = true;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
if (pred) { if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe; if (should_load_a) {
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
#pragma unroll #pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) { for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int row = a_idx / a_gl_stride;
int64_t sorted_row = 0; int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8) if (!m_block_size_8 || row < 8)
sorted_row = sh_block_sorted_ids[row] / top_k; sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; int64_t true_idx =
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens); row < block_num_valid_tokens);
} }
} }
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
...@@ -1063,8 +992,8 @@ __global__ void Marlin( ...@@ -1063,8 +992,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe // Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer. // into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) { auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>( ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
...@@ -1109,12 +1038,17 @@ __global__ void Marlin( ...@@ -1109,12 +1038,17 @@ __global__ void Marlin(
} }
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage = if (k % b_sh_wr_iters == 0) {
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_s_stage =
(pipe / (group_blocks / thread_k_blocks))); sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
}
} else { } else {
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
...@@ -1123,12 +1057,19 @@ __global__ void Marlin( ...@@ -1123,12 +1057,19 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters); cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16; int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks; int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = if constexpr (w_type_id != vllm::kFE2M1f.id()) {
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} }
} }
...@@ -1152,7 +1093,7 @@ __global__ void Marlin( ...@@ -1152,7 +1093,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and // Determine "position" inside the thread-block (based on warp and
// thread-id) // thread-id)
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
...@@ -1161,7 +1102,7 @@ __global__ void Marlin( ...@@ -1161,7 +1102,7 @@ __global__ void Marlin(
cur_k += warp_row * 16; cur_k += warp_row * 16;
int th_id = threadIdx.x % 32; auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift = int s_col_shift =
...@@ -1222,15 +1163,18 @@ __global__ void Marlin( ...@@ -1222,15 +1163,18 @@ __global__ void Marlin(
} }
} else if constexpr (group_blocks >= thread_k_blocks) { } else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage = if (k % b_sh_wr_iters == 0) {
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_zp_stage =
(pipe / (group_blocks / thread_k_blocks))); sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
for (int i = 0; i < num_ints_per_thread; i++) { (pipe / (group_blocks / thread_k_blocks)));
frag_qzp[k % 2][i] = #pragma unroll
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} }
} else { } else {
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
...@@ -1251,6 +1195,7 @@ __global__ void Marlin( ...@@ -1251,6 +1195,7 @@ __global__ void Marlin(
sh_zp_stage += cur_group_id * zp_sh_stride; sh_zp_stage += cur_group_id * zp_sh_stride;
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
...@@ -1263,12 +1208,16 @@ __global__ void Marlin( ...@@ -1263,12 +1208,16 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage = if (k % b_sh_wr_iters == 0) {
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_zp_stage =
(pipe / (group_blocks / thread_k_blocks))); sh_zp +
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd];
}
} else { } else {
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / n_warps;
...@@ -1292,6 +1241,10 @@ __global__ void Marlin( ...@@ -1292,6 +1241,10 @@ __global__ void Marlin(
} }
}; };
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
};
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true; bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) { auto matmul = [&](int k) {
...@@ -1315,15 +1268,27 @@ __global__ void Marlin( ...@@ -1315,15 +1268,27 @@ __global__ void Marlin(
zp_quant_1 = frag_qzp[k2][1]; zp_quant_1 = frag_qzp[k2][1];
} }
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0); dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1); dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
frag_zp[0] = frag_zp_0[0]; }
frag_zp[1] = frag_zp_0[1]; if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
frag_zp[2] = frag_zp_1[0]; if (is_new_zp) {
frag_zp[3] = frag_zp_1[1]; reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
} }
} }
if constexpr (w_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2>(s_quant_0,
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
...@@ -1332,7 +1297,10 @@ __global__ void Marlin( ...@@ -1332,7 +1297,10 @@ __global__ void Marlin(
FragB frag_b1; FragB frag_b1;
int b_quant_0, b_quant_1; int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) { if constexpr (w_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j]; b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8; b_quant_1 = b_quant_0 >> 8;
} else { } else {
...@@ -1342,8 +1310,13 @@ __global__ void Marlin( ...@@ -1342,8 +1310,13 @@ __global__ void Marlin(
b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
} }
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0); dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1); dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order) {
...@@ -1351,9 +1324,9 @@ __global__ void Marlin( ...@@ -1351,9 +1324,9 @@ __global__ void Marlin(
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2; int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2( scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
...@@ -1361,18 +1334,12 @@ __global__ void Marlin( ...@@ -1361,18 +1334,12 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x); scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y); scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
if (is_new_zp) if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j], frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j])); *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
if (is_new_zp)
frag_zpf[k2][j] = __hmul2(
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0); scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1); scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
...@@ -1397,7 +1364,7 @@ __global__ void Marlin( ...@@ -1397,7 +1364,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() { auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2; constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) { if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads; auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads; constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
...@@ -1634,10 +1601,17 @@ __global__ void Marlin( ...@@ -1634,10 +1601,17 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for // For per-column quantization we finally apply the scale here (only for
// 4-bit) // 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 && !has_zp) { w_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
res = __hmul2(res, s[0]); res = __hmul2(res, s[0]);
} }
if constexpr (w_type == vllm::kFE2M1f) {
if (!mul_topk_weights) {
res = __hmul2(res, global_scale);
}
}
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
...@@ -1728,10 +1702,12 @@ __global__ void Marlin( ...@@ -1728,10 +1702,12 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) { if (i == 0) {
fetch_col_zp_to_shared(); fetch_col_zp_to_shared();
fetch_col_scale_to_shared(); if constexpr (!dequant_skip_flop) {
fetch_col_scale_to_shared();
}
} }
} }
fetch_to_shared(i, i, i < slice_iters); fetch_to_shared(i, i, i < slice_iters, i);
} }
zero_accums(); zero_accums();
...@@ -1740,8 +1716,10 @@ __global__ void Marlin( ...@@ -1740,8 +1716,10 @@ __global__ void Marlin(
fetch_to_registers(0, 0); fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0); fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0); fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1); a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1); if constexpr (has_act_order) {
slice_k_start_shared_fetch += tb_k * (stages - 1);
}
}; };
if (slice_iters) { if (slice_iters) {
start_pipes(); start_pipes();
...@@ -1754,43 +1732,59 @@ __global__ void Marlin( ...@@ -1754,43 +1732,59 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at // have even length meaning that the next iteration will always start at
// index 0. // index 0.
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
#pragma unroll #pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) { for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages); int idx =
fetch_scales_to_registers(k + 1, pipe); (pipe >= stages && stage_group_id == max_num_stage_groups - 1)
fetch_zp_to_registers(k + 1, pipe); ? (pipe - stages)
if (k == b_sh_wr_iters - 2) { : (pipe + stage_group_id * stages);
fetch_to_shared((pipe + stages - 1) % stages, pipe, fetch_to_registers(k + 1, pipe % stages, idx);
slice_iters >= stages); fetch_scales_to_registers(k + 1, pipe);
pipe++; fetch_zp_to_registers(k + 1, pipe);
wait_for_stage(); if (k == b_sh_wr_iters - 2) {
init_same_group(pipe % stages); int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
} }
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
} }
}
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages; a_gl_rd_col += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) { if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start]; slice_k_start += tb_k * stages;
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) { if (slice_k_start < prob_k) {
last_g_idx = prob_k - 1; slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
__syncthreads();
}
}
} }
int last_group_id = g_idx[last_g_idx]; if (slice_iters == 0) {
if (last_group_id >= sh_first_group_id + sh_num_groups) { break;
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
} }
} }
...@@ -1802,7 +1796,8 @@ __global__ void Marlin( ...@@ -1802,7 +1796,8 @@ __global__ void Marlin(
bool last = slice_idx == slice_count - 1; bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
...@@ -1812,7 +1807,8 @@ __global__ void Marlin( ...@@ -1812,7 +1807,8 @@ __global__ void Marlin(
} }
thread_block_reduce(); thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
...@@ -1836,7 +1832,8 @@ __global__ void Marlin( ...@@ -1836,7 +1832,8 @@ __global__ void Marlin(
// that converts the fp32 results to fp16 (so that we avoid possible // that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16) // overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 && !has_zp) { w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
...@@ -1877,15 +1874,30 @@ __global__ void Marlin( ...@@ -1877,15 +1874,30 @@ __global__ void Marlin(
if (last || use_atomic_add) if (last || use_atomic_add)
// only the last block in a slice actually writes the result // only the last block in a slice actually writes the result
write_result(); write_result();
if (slice_row) a_remaining_load_count_in_slice = stages; int old_slice_row = slice_row;
slice_row = 0; slice_row = 0;
slice_col_par++; slice_col_par++;
slice_col++; slice_col++;
is_first_matmul_in_slice = true; is_first_matmul_in_slice = true;
init_slice(); init_slice();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if (slice_col == 0 || old_slice_row ||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
should_load_a = true;
} else {
should_load_a = false;
}
if (slice_iters) { if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
...@@ -1900,12 +1912,10 @@ __global__ void Marlin( ...@@ -1900,12 +1912,10 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start; slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col; slice_n_offset = act_s_col_tb_stride * slice_col;
} else { } else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} }
start_pipes(); start_pipes();
} }
} }
......
...@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel( ...@@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int base_k = 0; int base_k = 0;
for (int i = 0; i < iters; i++) { for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x; auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k]; int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos]; out_half[cur_k] = a_row_half[src_pos];
...@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel( ...@@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if (rest) { if (rest) {
if (threadIdx.x < rest) { if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x; auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k]; int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos]; out_half[cur_k] = a_row_half[src_pos];
...@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, ...@@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2; return load_groups * tb_n * 2;
} else { } else {
int tb_scales = tb_groups * tb_n * 2; int tb_scales = tb_groups * tb_n * 2;
...@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, ...@@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
} }
} }
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int prob_m, int prob_n, int prob_k, int num_bits, int thread_m_blocks, int prob_m, int prob_n,
int group_size, bool has_act_order, bool is_k_full, int prob_k, int num_bits, int group_size,
int has_zp, int is_zp_float) { bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int pack_factor = 32 / num_bits; int pack_factor = 32 / num_bits;
// Get B size // Get B size
int tb_k = th_config.thread_k; int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n; int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16; int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
// shm size for block_sorted_ids/block_topk_weights // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4 * 2; int sh_block_meta_size = tb_m * 4;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_s_size = int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full); group_size, has_act_order, is_k_full);
...@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, ...@@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
sh_zp_size = sh_s_size / 2; sh_zp_size = sh_s_size / 2;
} }
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_g_idx_size + sh_block_meta_size; sh_zp_size + sh_g_idx_size + sh_block_meta_size;
return total_size; return total_size;
} }
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int prob_m, int prob_n, int prob_k, int num_bits, int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int group_size, bool has_act_order, bool is_k_full, int num_bits, int group_size, bool has_act_order,
int has_zp, int is_zp_float, int max_shared_mem) { bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
...@@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, ...@@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
// Check that pipeline fits into cache // Check that pipeline fits into cache
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem; return cache_size <= max_shared_mem;
} }
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
NUM_THREADS, IS_ZP_FLOAT) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ m_block_size_8 == M_BLOCK_SIZE_8 && \
m_block_size_8 == M_BLOCK_SIZE_8 && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ is_zp_float == IS_ZP_FLOAT) { \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
is_zp_float == IS_ZP_FLOAT) { \ THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \ pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
} }
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \ // this is the most common cases
false) \ // BIGGROUP: cases for big group size (group_blocks in [-1, 8])
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ // FZP: cases for float-zero-point (is_zp_float = true)
NUM_THREADS, false) \ // ACT: cases for act order case (group_blocks == 0)
\ // FP4: cases for nvfp4(e2m1) (group_blocks == 1)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \ #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
\ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \ #define COMMON_GET_IF(W_TYPE) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
NUM_THREADS, false) \ COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
NUM_THREADS, false) \ COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \ #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
\ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
\ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
NUM_THREADS, false)
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \ #define FP4_GET_IF(W_TYPE) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
false) \ FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \ FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
false) \ FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \ #define BIGGROUP_GET_IF(W_TYPE) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
NUM_THREADS, false) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
NUM_THREADS, false) \ BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) // We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \ #define FZP_GET_IF(W_TYPE) \
\ FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
NUM_THREADS, false) \ FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4 // We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
true) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \ #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
NUM_THREADS, true) \ _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ #define ACT_GET_IF(W_TYPE) \
NUM_THREADS, true) ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t> template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
...@@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, ...@@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto kernel = MarlinDefault; auto kernel = MarlinDefault;
if (false) { if (false) {
} }
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256) COMMON_GET_IF(vllm::kU4)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128) COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256) BIGGROUP_GET_IF(vllm::kFE4M3fn)
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256) FP4_GET_IF(vllm::kFE2M1f)
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256) ACT_GET_IF(vllm::kU4B8)
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128) ACT_GET_IF(vllm::kU8B128)
return kernel; return kernel;
} }
...@@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for (int i = 0; i < thread_configs_size; i++) { for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i]; thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
num_bits, group_size, has_act_order, is_k_full, has_zp, prob_n, prob_k, num_bits, group_size, has_act_order,
is_zp_float, max_shared_mem)) { is_k_full, has_zp, is_zp_float, max_shared_mem)) {
continue; continue;
} }
int cache_size = get_kernel_cache_size( int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
group_size, has_act_order, is_k_full, has_zp, is_zp_float); num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int group_blocks = 0; int group_blocks = 0;
if (!has_act_order) { if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : group_size / 16; group_blocks = group_size == -1 ? -1 : (group_size / 16);
} }
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel<scalar_t>(
...@@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, ...@@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
template <typename scalar_t> template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp, void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
void* sorted_token_ids, void* expert_ids, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights, void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
...@@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else { } else {
TORCH_CHECK( TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128, q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str()); q_type.str());
} }
...@@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int4* C_ptr = (int4*)C; int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp; int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s; const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const int4* zp_ptr = (const int4*)zp; const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx; const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm; const int* perm_ptr = (const int*)perm;
...@@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int thread_k_blocks = thread_k / 16; int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16; int thread_n_blocks = thread_n / 16;
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, TORCH_CHECK(
prob_k, num_bits, group_size, has_act_order, is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
is_k_full, has_zp, is_zp_float, max_shared_mem), prob_n, prob_k, num_bits, group_size, has_act_order,
"Invalid thread config: thread_m_blocks = ", thread_m_blocks, is_k_full, has_zp, is_zp_float, max_shared_mem),
", thread_k = ", thread_tfg.thread_k, "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_n = ", thread_tfg.thread_n, ", thread_k = ", thread_tfg.thread_k,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", ", thread_n = ", thread_tfg.thread_n,
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
", group_size = ", group_size, prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", group_size = ", group_size, ", has_act_order = ", has_act_order,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, ", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", max_shared_mem = ", max_shared_mem); ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>( auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
...@@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// avoid ">>>" being formatted to "> > >" // avoid ">>>" being formatted to "> > >"
// clang-format off // clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>( kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
// clang-format on // clang-format on
} }
...@@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, ...@@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
torch::Tensor moe_wna16_marlin_gemm( torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none, torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_q_weight, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none, std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none, std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace, std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
...@@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm(
} }
} }
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
"global_scale can only be used for float4_e2m1f.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
"the global_scale parameter must be passed for float4_e2m1f.");
}
torch::Tensor b_zeros; torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) { if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value(); b_zeros = b_zeros_or_none.value();
...@@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm(
if (has_zp) { if (has_zp) {
TORCH_CHECK( TORCH_CHECK(
b_q_type == vllm::kU4, b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
} else { } else {
TORCH_CHECK( TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
b_q_type.str()); "float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
} }
if (has_zp && is_zp_float) { if (has_zp && is_zp_float) {
...@@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm(
int dev = a.get_device(); int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) { if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>( MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(), c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(), a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
...@@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm(
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float); use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) { } else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>( MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(), global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
......
...@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
} }
if (use_global_memory) { if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors // tensors
...@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
cumsum_buffer.data_ptr<int32_t>()); cumsum_buffer.data_ptr<int32_t>());
}); });
} else if (use_i16) { } else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem // set dynamic shared mem
auto kernel = auto kernel =
...@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel()); topk_ids.numel());
}); });
} else { } else {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto kernel = auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>; vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
...@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
TORCH_CHECK(num_experts == 256, TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3."); "sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors // calc needed amount of shared mem for `cumsum` tensors
auto options_int = auto options_int =
......
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
// Host-side entry point that reorders the rows of `input` by routed expert
// ahead of a grouped GEMM.
//
// Steps (all work is enqueued on the current CUDA stream):
//   1. With an `expert_map` (expert parallelism), rewrite `topk_ids` in place
//      so that local experts sort first.
//   2. Sort the (expert id, source row) pairs and scan them to fill
//      `expert_first_token_offset`.
//   3. Launch the row-expansion kernel for the dtype of `input`.
//   4. Compute `m_indices`; when `align_block_size` is set, overwrite
//      `expert_first_token_offset` with the block-aligned offsets.
//
// Parameters (shapes as given by the inline comments on the signature):
//   input, topk_weights, token_expert_indicies  read-only inputs.
//   topk_ids                   modified in place when `expert_map` is set.
//   expert_map                 optional int32 global->local expert id map.
//   align_block_size           optional alignment; -1 is forwarded to the
//                              launchers when it is not set.
//   permuted_input, expert_first_token_offset, src_row_id2dst_row_id_map,
//   m_indices                  output buffers passed to the launchers.
//
// Raises (via TORCH_CHECK) when a dtype or shape precondition is violated.
void moe_permute(
    const torch::Tensor& input,                      // [n_token, hidden]
    const torch::Tensor& topk_weights,               // [n_token, topk]
    torch::Tensor& topk_ids,                         // [n_token, topk]
    const torch::Tensor& token_expert_indicies,      // [n_token, topk]
    const std::optional<torch::Tensor>& expert_map,  // [n_expert]
    int64_t n_expert, int64_t n_local_expert, int64_t topk,
    const std::optional<int64_t>& align_block_size,
    torch::Tensor&
        permuted_input,  // [topk * n_token/align_block_size_m, hidden]
    torch::Tensor& expert_first_token_offset,  // [n_local_expert + 1]
    torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
    torch::Tensor& m_indices) {                // [align_expand_m]
  // get_ptr<T>() is used below to obtain typed raw pointers, so every dtype
  // that is handed to a launcher is validated up front.
  TORCH_CHECK(input.dim() == 2, "input must be 2D [n_token, hidden]");
  TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
              "topk_weights must be float32");
  TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
              "expert_first_token_offset must be int64");
  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
              "topk_ids must be int32");
  TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
              "token_expert_indicies must be int32");
  TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
              "src_row_id2dst_row_id_map must be int32");
  TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
              "expert_first_token_offset shape != n_local_expert+1");
  TORCH_CHECK(
      src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
      "token_expert_indicies shape must be same as src_row_id2dst_row_id_map");

  const int64_t n_token = input.size(0);
  const int64_t n_hidden = input.size(1);
  // The sort below consumes n_token * topk ids; reject a smaller/larger
  // topk_ids instead of reading out of bounds.
  TORCH_CHECK(topk_ids.numel() == n_token * topk,
              "topk_ids must contain n_token * topk elements");

  // -1 is the value forwarded to the launchers when no alignment is requested.
  const int64_t align_block_size_value =
      align_block_size.has_value() ? align_block_size.value() : -1;
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  // Scratch space for the key/value sort. int64_t (not `long`) keeps the size
  // 64-bit on every platform, and the buffer lives on the same device as
  // `input` rather than on whichever CUDA device happens to be current.
  const int64_t sorter_size = static_cast<int64_t>(
      CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert));
  auto sort_workspace = torch::empty(
      {sorter_size},
      torch::dtype(torch::kInt8).device(input.device()).requires_grad(false));
  auto permuted_experts_id = torch::empty_like(topk_ids);
  auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
  // Zero-initialised; filled by getMIndices and copied back only when an
  // alignment was requested.
  auto align_expert_first_token_offset =
      torch::zeros_like(expert_first_token_offset);

  CubKeyValueSorter sorter{};
  // Stays null without an expert_map; otherwise points at the last entry of
  // expert_first_token_offset (index n_local_expert).
  int64_t* valid_num_ptr = nullptr;

  // Pre-processing for expert parallelism.
  // Experts that are not local to this rank get "+ n_expert" added to their
  // id, and local experts are remapped from their global ids to
  // [0, n_local_expert - 1], so that local experts sort first.
  // Example: 4 experts, ep_size = 2. ep_rank = 1 owns global experts [2, 3]
  // with expert_map = [-1, -1, 0, 1]. The pre-process step maps global ids
  // [2, 3] to local ids [0, 1] and global ids [0, 1] (not on ep_rank 1) to
  // [4, 5]. The following sort and scan then produce
  // expert_first_token_offset for the local experts of this rank, ready for
  // the grouped GEMM.
  if (expert_map.has_value()) {
    TORCH_CHECK(expert_map.value().scalar_type() == at::ScalarType::Int,
                "expert_map must be int32");
    const int* expert_map_ptr = get_ptr<int>(expert_map.value());
    valid_num_ptr =
        get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
    preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
                             expert_map_ptr, n_expert, stream);
  }

  // Sort the top-k expert ids and scan them to obtain
  // expert_first_token_offset.
  sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
                    get_ptr<int>(permuted_experts_id),
                    get_ptr<int>(dst_row_id2src_row_id_map),
                    get_ptr<int64_t>(expert_first_token_offset), n_token,
                    n_expert, n_local_expert, topk, sorter,
                    get_ptr<int>(sort_workspace), stream);

  // Dispatch expandInputRowsKernelLauncher on the dtype of `input`.
  MOE_DISPATCH(input.scalar_type(), [&] {
    expandInputRowsKernelLauncher<scalar_t>(
        get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
        get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
        get_ptr<int>(dst_row_id2src_row_id_map),
        get_ptr<int>(src_row_id2dst_row_id_map),
        get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
        n_hidden, topk, n_local_expert, align_block_size_value, stream);
  });

  // Compute m_indices and the block-aligned expert offsets.
  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
              get_ptr<int64_t>(align_expert_first_token_offset),
              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
              stream);
  if (align_block_size.has_value()) {
    // Publish the aligned offsets to the caller.
    expert_first_token_offset.copy_(align_expert_first_token_offset);
  }
}
// Host-side entry point that launches the MoE "finalize routing" kernel,
// combining the expert outputs in `permuted_hidden_states` into
// `hidden_states` on the current CUDA stream.
//
// Parameters (shapes as given by the inline comments on the signature):
//   permuted_hidden_states     expert outputs; same dtype as hidden_states.
//   topk_weights               float32 routing weights.
//   topk_ids                   int32 expert ids.
//   src_row_id2dst_row_id_map  int32 row map; same shape as topk_ids.
//   expert_first_token_offset  int64 offsets; the entry at index
//                              n_local_expert is passed to the kernel
//                              launcher as `valid_ptr`.
//   n_expert                   not used by this function.
//   hidden_states              output buffer.
//
// Raises (via TORCH_CHECK) when a dtype or shape precondition is violated.
void moe_unpermute(
    const torch::Tensor& permuted_hidden_states,     // [n_token * topk, hidden]
    const torch::Tensor& topk_weights,               // [n_token, topk]
    const torch::Tensor& topk_ids,                   // [n_token, topk]
    const torch::Tensor& src_row_id2dst_row_id_map,  // [n_token, topk]
    const torch::Tensor& expert_first_token_offset,  // [n_local_expert+1]
    int64_t n_expert, int64_t n_local_expert, int64_t topk,
    torch::Tensor& hidden_states  // [n_token, hidden]
) {
  TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
              "topk_ids shape must be same as src_row_id2dst_row_id_map");
  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
              "topk_ids must be int32");
  // The message now describes the tensors that are actually compared.
  TORCH_CHECK(
      permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
      "permuted_hidden_states dtype must be same as hidden_states");
  // get_ptr<T>() is used below to obtain typed raw pointers, so validate the
  // remaining dtypes the same way moe_permute does.
  TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
              "topk_weights must be float32");
  TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
              "src_row_id2dst_row_id_map must be int32");
  TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
              "expert_first_token_offset must be int64");
  // valid_ptr below addresses element n_local_expert; make sure it exists.
  TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
              "expert_first_token_offset shape != n_local_expert+1");
  TORCH_CHECK(hidden_states.dim() == 2,
              "hidden_states must be 2D [n_token, hidden]");

  const int64_t n_token = hidden_states.size(0);
  const int64_t n_hidden = hidden_states.size(1);
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  // Pointer to the last entry of expert_first_token_offset.
  const int64_t* valid_ptr =
      get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;

  // Dispatch on the dtype shared by hidden_states and permuted_hidden_states.
  MOE_DISPATCH(hidden_states.scalar_type(), [&] {
    finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
        get_ptr<scalar_t>(permuted_hidden_states),
        get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
        get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
        n_token, n_hidden, topk, valid_ptr, stream);
  });
}
// Registers the two host functions above as the CUDA-backend implementations
// of the "moe_permute" and "moe_unpermute" operators in the
// TORCH_EXTENSION_NAME library.
// NOTE(review): the operator schemas are not defined in this file; presumably
// they are declared in a separate library-definition block — confirm.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}
\ No newline at end of file
...@@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) { ...@@ -108,11 +108,11 @@ __device__ inline void dequant<half2, 4>(int q, half2* res) {
const int MUL = 0x2c002c00; const int MUL = 0x2c002c00;
const int ADD = 0xd400d400; const int ADD = 0xd400d400;
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
q >>= 8; q >>= 8;
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0), res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
*reinterpret_cast<const half2*>(&SUB)); *reinterpret_cast<const half2*>(&SUB));
...@@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) { ...@@ -149,13 +149,13 @@ __device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300; static constexpr uint32_t EX = 0x43004300;
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4; q >>= 4;
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4; q >>= 4;
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4; q >>= 4;
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300; static constexpr uint32_t ADD = 0xC300C300;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment