Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "utils.hpp"
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
......@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const int32_t n_partition_size = [&]() {
const int64_t cache_size = cpu_utils::get_l2_size();
const int64_t cache_size = cpu_utils::get_available_l2_size();
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
int64_t ps_thread_limit = n_size / thread_num;
ps_cache_limit =
......@@ -179,7 +178,7 @@ void cpu_gemm_wna16_impl(
const int64_t b_buffer_offset = 0;
const int64_t c_buffer_offset = b_buffer_size;
const int64_t buffer_size = b_buffer_size + c_buffer_size;
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size *
thread_num);
alignas(64) cpu_utils::Counter counter;
......@@ -190,7 +189,8 @@ void cpu_gemm_wna16_impl(
scalar_t* __restrict__ b_buffer = nullptr;
float* __restrict__ c_buffer = nullptr;
{
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
uint8_t* buffer_ptr =
cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<uint8_t>() +
thread_id * buffer_size;
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
......
......@@ -4,8 +4,8 @@
#include "common/memory_desc.hpp"
#include "common/memory.hpp"
#include "dnnl_helper.h"
#include "scratchpad_manager.h"
#include "cpu/utils.hpp"
#include "cpu/dnnl_helper.h"
static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
......@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
......@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
......@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
......@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
......
......@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
}
}
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t);
TORCH_CHECK_EQ(output_size % 16, 0);
TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0);
const int32_t output_group_num = output_size / 16;
const int32_t input_32b_num = input_size / elem_num_per_group;
for (int32_t output_group_idx = 0; output_group_idx < output_group_num;
++output_group_idx) {
const int32_t* __restrict__ weight_32b =
reinterpret_cast<const int32_t*>(weight);
int32_t* __restrict__ packed_weight_32b =
reinterpret_cast<int32_t*>(packed_weight);
for (int32_t output_idx = 0; output_idx < 16; ++output_idx) {
for (int32_t weight_offset = 0, packed_offset = 0;
weight_offset < input_32b_num;
++weight_offset, packed_offset += 16) {
packed_weight_32b[packed_offset] = weight_32b[weight_offset];
}
// update
weight_32b += input_32b_num;
packed_weight_32b += 1;
}
// update
weight += 16 * input_size;
packed_weight += 16 * input_size;
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t curr_m_;
......
......@@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous
// blocks, means the logical shape of blocks is [16, input_size]. And the actual
// layout of blocks can be ISA-specific.
template <cpu_utils::ISA isa, typename scalar_t>
class MicroGemm {
public:
......@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
curr_d += ldd;
}
}
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr,
scalar_t* __restrict__ bias_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
constexpr int32_t n_group_num = n_size / 16;
static_assert(n_group_num <= 16);
vec_op::FP32Vec16 bias_vecs[n_group_num];
scalar_t* __restrict__ curr_bias = bias_ptr;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
scalar_vec_t vec(curr_bias);
bias_vecs[i] = vec_op::FP32Vec16(vec);
curr_bias += 16;
});
float* curr_c = c_ptr;
float* curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* curr_c_iter = curr_c;
float* curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
c_vec_fp32.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
} // namespace cpu_micro_gemm
#endif
......@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
// Note: pack contiguous weight [output_size, input_size] as contiguous
// packed weight [output_size / 16, input_size, 16]
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
TORCH_CHECK_EQ(output_size % 16, 0);
for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) {
const scalar_t* __restrict__ curr_weight = weight + o_idx * input_size;
scalar_t* __restrict__ curr_packed_weight =
packed_weight + (o_idx / 16) * (16 * input_size) + o_idx % 16;
for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) {
*curr_packed_weight = *curr_weight;
curr_packed_weight += 16;
++curr_weight;
}
}
}
};
} // namespace cpu_micro_gemm
......
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif
......@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
void prepack_moe_weight(const torch::Tensor& weight,
torch::Tensor& packed_weight, const std::string& isa);
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const torch::Tensor& w13, const torch::Tensor& w2,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
......@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
// fused moe
#if defined(__AVX512F__)
ops.def(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()");
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
......
......@@ -10,7 +10,7 @@
#define gettid() syscall(SYS_gettid)
#endif
#include "cpu_types.hpp"
#include "cpu/utils.hpp"
#ifdef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
......@@ -24,6 +24,8 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
#ifndef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask != nullptr,
"Failed to parse CPU string: " + cpu_ids);
TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size);
......@@ -44,20 +46,12 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
std::set<int> node_ids;
for (const auto& cpu_id : omp_cpu_ids) {
int node_id = numa_node_of_cpu(cpu_id);
if (node_id != -1) {
node_ids.insert(node_id);
}
if (node_id != mem_node_id) {
TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ",
omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
". All CPUs should be on the same NUMA node for optimal "
"performance. Memory will be bound to NUMA node ",
mem_node_id, ".");
}
}
// Concatenate all node_ids into a single comma-separated string
if (!node_ids.empty()) {
......@@ -70,7 +64,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
}
bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
bitmask* src_mask = numa_get_membind();
bitmask* src_mask = numa_get_mems_allowed();
int pid = getpid();
......@@ -83,14 +77,45 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
std::to_string(errno));
}
// restrict memory allocation node.
// Restrict memory allocation to the selected NUMA node(s).
// Enhances memory locality for the threads bound to those NUMA CPUs.
if (node_ids.size() > 1) {
errno = 0;
numa_set_interleave_mask(mask);
if (errno != 0) {
TORCH_WARN("numa_set_interleave_mask failed. errno: " +
std::to_string(errno));
} else {
TORCH_WARN(
"NUMA binding: Using INTERLEAVE policy for memory "
"allocation across multiple NUMA nodes (nodes: " +
node_ids_str +
"). Memory allocations will be "
"interleaved across the specified NUMA nodes.");
}
} else {
errno = 0;
numa_set_membind(mask);
if (errno != 0) {
TORCH_WARN("numa_set_membind failed. errno: " +
std::to_string(errno));
} else {
TORCH_WARN(
"NUMA binding: Using MEMBIND policy for memory "
"allocation on the NUMA nodes (" +
node_ids_str +
"). Memory allocations will be "
"strictly bound to these NUMA nodes.");
}
}
numa_set_strict(1);
numa_free_nodemask(mask);
numa_free_nodemask(src_mask);
} else {
TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " +
TORCH_WARN(
"numa_parse_nodestring or numa_get_run_node_mask failed. errno: " +
std::to_string(errno));
}
}
......@@ -138,4 +163,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
return ss.str();
}
#endif
#endif // VLLM_NUMA_DISABLED
namespace cpu_utils {
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void ScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
static ScratchPadManager manager;
return &manager;
}
} // namespace cpu_utils
......@@ -2,19 +2,24 @@
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#include <ATen/cpu/Utils.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "cpu/cpu_types.hpp"
namespace cpu_utils {
enum class ISA { AMX, VEC };
inline ISA get_isa(const std::string& isa) {
if (isa == "amx") {
return ISA::AMX;
} else if (isa == "vec") {
return ISA::VEC;
} else {
TORCH_CHECK(false, "Invalid isa type: " + isa);
}
}
template <typename T>
struct VecTypeTrait {
using vec_t = void;
......@@ -32,10 +37,12 @@ struct VecTypeTrait<c10::BFloat16> {
};
#endif
#if !defined(__powerpc__)
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
#endif
struct Counter {
std::atomic<int64_t> counter;
......@@ -48,26 +55,66 @@ struct Counter {
int64_t acquire_counter() { return counter++; }
};
inline int64_t get_l2_size() {
inline int64_t get_available_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128LL * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
assert(l2_cache_size != -1);
const uint32_t l2_cache_size = at::cpu::L2_cache_size();
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
template <int32_t alignment_v, typename T>
inline T round_up(T size) {
T alignment = alignment_v;
return (((size + alignment - 1) / alignment) * alignment);
}
template <int32_t alignment_v, typename T>
inline T round_down(T size) {
T alignment = alignment_v;
return (size / alignment) * alignment;
}
template <typename T>
inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
class ScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static ScratchPadManager* get_scratchpad_manager();
ScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
} // namespace cpu_utils
#endif
......@@ -107,6 +107,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
#ifndef USE_ROCM
int flag = 0;
CUDA_CHECK(cuDeviceGetAttribute(
&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED,
device));
if (flag) { // support GPUDirect RDMA if possible
prop.allocFlags.gpuDirectRDMACapable = 1;
}
#endif
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
......
......@@ -107,7 +107,8 @@ __global__ void fusedQKNormRopeKernel(
void const* k_weight_void, // RMSNorm weights for key
void const* cos_sin_cache_void, // Pre-computed cos/sin cache
int64_t const* position_ids, // Position IDs for RoPE
int const num_tokens // Number of tokens
int const num_tokens, // Number of tokens
int const rotary_dim // Dimension for RoPE
) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
......@@ -227,22 +228,24 @@ __global__ void fusedQKNormRopeKernel(
// Calculate cache pointer for this position - similar to
// pos_encoding_kernels.cu
T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
int const embed_dim = head_dim / 2;
T_cache const* cache_ptr = cos_sin_cache + pos_id * rotary_dim;
int const embed_dim = rotary_dim / 2;
T_cache const* cos_ptr = cache_ptr;
T_cache const* sin_ptr = cache_ptr + embed_dim;
int const rotary_lanes = rotary_dim / numElemsPerThread; // rotary range
if (laneId < rotary_lanes) {
if constexpr (interleave) {
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
for (int i = 0; i < numElemsPerThread / 2; ++i) {
int const idx0 = 2 * i;
int const idx1 = 2 * i + 1;
// Global dimension index in the head
int const dim_idx = laneId * numElemsPerThread + idx0;
float const val0 = elements[idx0];
float const val1 = elements[idx1];
int const dim_idx = laneId * numElemsPerThread + idx0;
int const half_dim = dim_idx / 2;
float const cos_val =
CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
......@@ -255,19 +258,20 @@ __global__ void fusedQKNormRopeKernel(
} else {
// Before data exchange with in warp, we need to sync.
__syncwarp();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
int pairOffset = (rotary_dim / 2) / numElemsPerThread;
// Get the data from the other half of the warp. Use pre-computed
// cos/sin values.
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
if (laneId < 16) {
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], pairOffset);
if (laneId < pairOffset) {
elements2[i] = -elements2[i];
}
int dim_idx = laneId * numElemsPerThread + i;
dim_idx = (dim_idx * 2) % head_dim;
dim_idx = (dim_idx * 2) % rotary_dim;
int half_dim = dim_idx / 2;
// Use pre-computed cos/sin from cache
float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
......@@ -276,7 +280,7 @@ __global__ void fusedQKNormRopeKernel(
// __shfl_xor_sync does not provide memfence. Need to sync again.
__syncwarp();
}
}
// Store.
{
vec_T vec;
......@@ -312,10 +316,10 @@ template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
int const num_heads_q, int const num_heads_k,
int const num_heads_v, int const head_dim,
float const eps, void const* q_weight,
void const* k_weight, void const* cos_sin_cache,
bool const interleave, int64_t const* position_ids,
cudaStream_t stream) {
int const rotary_dim, float const eps,
void const* q_weight, void const* k_weight,
void const* cos_sin_cache, bool const interleave,
int64_t const* position_ids, cudaStream_t stream) {
constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32;
......@@ -332,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
});
break;
case 128:
......@@ -340,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
});
break;
case 256:
......@@ -348,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
k_weight, cos_sin_cache, position_ids, num_tokens, rotary_dim);
});
break;
default:
......@@ -392,8 +396,11 @@ void fused_qk_norm_rope(
"Query weights size must match head dimension");
TORCH_CHECK(k_weight.size(0) == head_dim,
"Key weights size must match head dimension");
TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
"Cos/sin cache dimension must match head_dim");
TORCH_CHECK(cos_sin_cache.size(1) % 2 == 0, "rotary_dim must be even");
TORCH_CHECK(cos_sin_cache.size(1) <= head_dim,
"rotary_dim must be less than or equal to head_dim");
TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
qkv.scalar_type() == k_weight.scalar_type(),
"qkv, q_weight and k_weight must have the same dtype");
......@@ -419,7 +426,8 @@ void fused_qk_norm_rope(
qkv.data_ptr(), static_cast<int>(num_tokens),
static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
static_cast<int>(num_heads_v), static_cast<int>(head_dim),
static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
static_cast<int>(cos_sin_cache.size(1)), static_cast<float>(eps),
q_weight.data_ptr(), k_weight.data_ptr(),
cos_sin_cache.data_ptr(), !is_neox,
reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
stream);
......
......@@ -446,15 +446,19 @@ __device__ inline T apply_sigmoid(T val) {
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
if constexpr (SF == SCORING_NONE) {
return val;
} else if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
"Unsupported ScoringFunc in apply_scoring");
return val;
}
}
template <typename T, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
template <typename T, typename BiasT, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, BiasT const* bias,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group) {
......@@ -465,7 +469,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = apply_scoring<SF>(input[i]);
value = value + bias[i];
value = value + static_cast<T>(bias[i]);
if (value > largest) {
second_largest = largest;
......@@ -477,7 +481,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = apply_scoring<SF>(input[i]);
value = value + bias[i];
value = value + static_cast<T>(bias[i]);
largest = value;
}
}
......@@ -499,8 +503,8 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
}
template <typename T, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
template <typename T, typename BiasT, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
......@@ -513,7 +517,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
input += case_id * num_experts_per_group;
// bias is per expert group, offset to current group
int32_t group_id = case_id % n_group;
T const* group_bias = bias + group_id * num_experts_per_group;
BiasT const* group_bias = bias + group_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block();
......@@ -522,7 +526,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
topk_with_k2<T, BiasT, SF>(output, input, group_bias, tile, lane_id,
num_experts_per_group);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
......@@ -530,10 +534,11 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#endif
}
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group,
BiasT const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor) {
......@@ -619,7 +624,7 @@ __global__ void group_idx_and_topk_idx_kernel(
T input = scores[offset + i];
if (is_finite(input)) {
T score = apply_scoring<SF>(input);
candidates = score + bias[offset + i];
candidates = score + static_cast<T>(bias[offset + i]);
}
}
queue.add(candidates, offset + i);
......@@ -670,10 +675,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) {
scale /= topk_sum;
}
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
: (base * routed_scaling_factor);
float value = base * scale;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = value;
}
......@@ -691,10 +699,10 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}
template <typename T, typename IdxT, ScoringFunc SF>
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
float* topk_values, IdxT* topk_indices, T const* bias,
float* topk_values, IdxT* topk_indices, BiasT const* bias,
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool const renormalize,
......@@ -708,36 +716,36 @@ inline void launch_group_idx_and_topk_kernel(
switch (n_group) {
case 4: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 4>);
break;
}
case 8: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 8>);
break;
}
case 16: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 16>);
break;
}
case 32: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 32>);
break;
}
default: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF>);
break;
}
}
}
template <typename T, typename IdxT>
template <typename T, typename BiasT, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk,
bool const renormalize, double const routed_scaling_factor,
int const scoring_func, bool enable_pdl = false,
cudaStream_t const stream = 0) {
IdxT* topk_indices, BiasT const* bias,
int64_t const num_tokens, int64_t const num_experts,
int64_t const n_group, int64_t const topk_group,
int64_t const topk, bool const renormalize,
double const routed_scaling_factor, int const scoring_func,
bool enable_pdl = false, cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
cudaLaunchConfig_t config;
......@@ -758,12 +766,12 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
};
switch (sf) {
case SCORING_NONE: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_NONE>;
launch_topk_with_k2(kernel_instance1);
break;
}
case SCORING_SIGMOID: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_SIGMOID>;
launch_topk_with_k2(kernel_instance1);
break;
}
......@@ -787,14 +795,14 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
config.attrs = attrs;
switch (sf) {
case SCORING_NONE: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_NONE>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
case SCORING_SIGMOID: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_SIGMOID>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
......@@ -805,17 +813,23 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
}
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>( \
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
template void invokeNoAuxTc<T, BiasT, IdxT>( \
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(float, float, int32_t);
INSTANTIATE_NOAUX_TC(float, half, int32_t);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(half, float, int32_t);
INSTANTIATE_NOAUX_TC(half, half, int32_t);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
} // end namespace moe
} // namespace vllm
......@@ -824,6 +838,7 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func = 0) {
auto data_type = scores.scalar_type();
auto bias_type = bias.scalar_type();
auto input_size = scores.sizes();
int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1];
......@@ -847,39 +862,62 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
} while (0)
switch (data_type) {
case torch::kFloat16:
// Handle Float16
vllm::moe::invokeNoAuxTc<half, int32_t>(
reinterpret_cast<half*>(scores.mutable_data_ptr()),
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
LAUNCH_KERNEL(half, int32_t);
break;
case torch::kFloat32:
// Handle Float32
vllm::moe::invokeNoAuxTc<float, int32_t>(
reinterpret_cast<float*>(scores.mutable_data_ptr()),
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
LAUNCH_KERNEL(float, int32_t);
break;
case torch::kBFloat16:
// Handle BFloat16
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
LAUNCH_KERNEL(__nv_bfloat16, int32_t);
break;
default:
// Handle other data types
......@@ -887,5 +925,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
"Invalid dtype, only supports float16, float32, and bfloat16");
break;
}
#undef LAUNCH_KERNEL
return {topk_values, topk_indices};
}
sm*_kernel_*.cu
kernel_selector.h
kernel_*.cu
......@@ -10,6 +10,8 @@ import jinja2
ARCHS = []
SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
......@@ -19,6 +21,10 @@ for arch in sys.argv[1].split(","):
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
if arch >= 80:
SUPPORT_SM80 = True
if arch == 75:
SUPPORT_SM75 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
......@@ -157,6 +163,7 @@ def remove_old_kernels():
def generate_new_kernels():
result_dict = {}
sm_75_result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
......@@ -174,6 +181,8 @@ def generate_new_kernels():
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
sm_75_result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
......@@ -197,17 +206,25 @@ def generate_new_kernels():
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"stages": 4,
"group_blocks": group_blocks,
"is_zp_float": "false",
}
if SUPPORT_SM80:
result_dict[(a_type, b_type, c_type)].append(config)
if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75:
config_sm75 = config.copy()
config_sm75["stages"] = 2
sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
for result_dict_tmp in [result_dict, sm_75_result_dict]:
for (a_type, b_type, c_type), config_list in result_dict_tmp.items():
all_template_str_list = []
if not config_list:
continue
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
......@@ -229,6 +246,7 @@ def generate_new_kernels():
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"stages == {config['stages']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
......@@ -262,6 +280,8 @@ def generate_new_kernels():
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
elif result_dict_tmp is sm_75_result_dict:
filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
......
......@@ -19,8 +19,8 @@
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool mul_topk_weights, int num_groups, int prob_m, int prob_n, \
int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
......
......@@ -26,6 +26,7 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......@@ -35,7 +36,7 @@
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
......@@ -70,7 +71,6 @@ __global__ void Marlin(
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
......@@ -84,146 +84,6 @@ __global__ void Marlin(
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1200
asm volatile(
"mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, vllm::ScalarTypeId type_id>
......@@ -412,7 +272,6 @@ __global__ void Marlin(
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
......@@ -439,9 +298,20 @@ __global__ void Marlin(
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
// Turing TensorCore only supports fp16 and int8
if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id())
return;
#endif
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr bool use_fp16_accum = a_type_id == vllm::kFloat16.id();
#else
constexpr bool use_fp16_accum = false;
#endif
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
......@@ -504,14 +374,6 @@ __global__ void Marlin(
// parallel: num valid moe blocks
int parallel = num_tokens_past_padded / moe_block_size;
int num_valid_blocks = parallel;
if (is_ep) {
for (int i = 0; i < parallel; i++) {
if (expert_ids_ptr[i] == -1) num_valid_blocks--;
}
}
int num_invalid_blocks = parallel - num_valid_blocks;
parallel = num_valid_blocks;
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
......@@ -618,7 +480,22 @@ __global__ void Marlin(
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
if constexpr (moe_block_size >= 16)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16);
if constexpr (moe_block_size >= 8)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8);
if constexpr (moe_block_size >= 4)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4);
if constexpr (moe_block_size >= 2)
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2);
local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1);
block_num_valid_tokens = local_count;
#else
block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count);
#endif
if (lane_id == 0)
reinterpret_cast<int*>(sh_new)[0] = block_num_valid_tokens;
......@@ -651,22 +528,8 @@ __global__ void Marlin(
if (par_id >= parallel) return;
old_expert_id = expert_id;
if (num_invalid_blocks > 0) {
int skip_count = par_id;
for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) {
expert_id = expert_ids_ptr[i];
if (expert_id != -1) {
if (skip_count == 0) {
block_id = i;
break;
};
skip_count--;
};
}
} else {
block_id = par_id;
expert_id = expert_ids_ptr[block_id];
}
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = global_scale_ptr[expert_id];
......@@ -1018,10 +881,6 @@ __global__ void Marlin(
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
// Register storage for double buffer of shared memory reads.
......@@ -1545,11 +1404,13 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
mma_trans<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else {
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b0,
frag_c[i][j][0]);
mma<a_type_id, use_fp16_accum>(frag_a[k2][i], frag_b1,
frag_c[i][j][1]);
}
}
}
......@@ -1583,9 +1444,11 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
mma<a_type_id, false, 32>(
frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
......@@ -2132,6 +1995,21 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
// convert fp16 accum to fp32 for reduction
if constexpr (use_fp16_accum) {
#pragma unroll
for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
float* frag_c_part_float = reinterpret_cast<float*>(frag_c) + i * 4;
scalar_t* frag_c_part_half =
reinterpret_cast<scalar_t*>(frag_c_part_float);
#pragma unroll
for (int i = 3; i >= 0; i--) {
frag_c_part_float[i] = Cdtype::num2float(frag_c_part_half[i]);
}
}
}
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];
......
......@@ -142,7 +142,7 @@ typedef struct {
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool has_act_order, bool is_k_full, int stages) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
......@@ -160,13 +160,13 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
return tb_scales * stages;
}
}
......@@ -174,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float, bool is_a_8bit) {
int is_zp_float, bool is_a_8bit, int stages) {
int pack_factor = 32 / num_bits;
// Get B size
......@@ -185,8 +185,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
......@@ -195,8 +195,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
group_size, has_act_order, is_k_full, stages);
int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
......@@ -217,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem, bool is_a_8bit) {
bool is_a_8bit, int stages, int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
......@@ -243,7 +243,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
return cache_size <= max_shared_mem;
}
......@@ -252,7 +252,7 @@ MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int threads, bool is_zp_float, int stages) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
......@@ -266,8 +266,8 @@ exec_config_t determine_exec_config(
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_a_8bit) {
bool is_k_full, bool has_zp, bool is_zp_float, bool is_a_8bit, int stages,
int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
......@@ -284,15 +284,15 @@ exec_config_t determine_exec_config(
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem - 512)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
is_a_8bit, stages);
int group_blocks = 0;
if (!has_act_order) {
......@@ -303,7 +303,7 @@ exec_config_t determine_exec_config(
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
th_config.num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) continue;
......@@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int top_k, bool mul_topk_weights, int prob_m, int prob_n,
int prob_k, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int blocks_per_sm, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
......@@ -433,8 +433,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
TORCH_CHECK(major_capability * 10 + minor_capability >= 75,
"marlin kernel only support Turing or newer GPUs.");
int stages = 4;
if (major_capability == 7 && minor_capability == 5) {
stages = 2;
TORCH_CHECK(a_type == vllm::kFloat16 || a_type == vllm::kS8,
"Turing only support FP16 or INT8 activation.");
}
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
......@@ -461,8 +467,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages,
max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg;
}
......@@ -479,7 +485,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit),
is_a_8bit, stages, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
......@@ -493,12 +499,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
is_k_full, has_zp, is_zp_float, is_a_8bit, stages);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
num_threads, is_zp_float, stages);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
......@@ -517,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}
......@@ -535,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm(
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float, int64_t thread_k, int64_t thread_n,
......@@ -849,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm(
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
mul_topk_weights, size_m, size_n, size_k, workspace.data_ptr(), a_type,
b_type, c_type, s_type, has_bias, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
......
......@@ -80,7 +80,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_type_id,"
"bool mul_topk_weights, int b_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float,"
......
......@@ -2,6 +2,7 @@
#include <optional>
#include <torch/library.h>
#include <tuple>
#include "core/scalar_type.hpp"
......@@ -280,6 +281,11 @@ void get_cutlass_moe_mm_problem_sizes(
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
std::optional<bool> force_swap_ab = std::nullopt);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
......@@ -316,6 +322,12 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
......@@ -350,8 +362,9 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
// void static_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
// std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment