Commit 41199996 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.12.0' into v0.12.0-dev

parents 31021d81 4fd9d6a8
......@@ -192,7 +192,7 @@ class SHMManager {
const int group_size)
: _rank(rank),
_group_size(group_size),
_thread_num(torch::get_num_threads()),
_thread_num(omp_get_max_threads()),
_shm_names({""}),
_shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) {
......
......@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias, int64_t handler);
bool is_onednn_acl_supported();
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
......@@ -72,25 +74,45 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni);
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
const torch::Tensor& seq_lens, at::ScalarType dtype,
const torch::Tensor& query_start_loc, const bool casual,
const int64_t window_size, const std::string& isa_hint,
const bool enable_kv_split);
void cpu_attn_reshape_and_cache(const torch::Tensor& key,
const torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
const torch::Tensor& slot_mapping,
const std::string& isa);
void cpu_attention_with_kv_cache(
const torch::Tensor& query, const torch::Tensor& key_cache,
const torch::Tensor& value_cache, torch::Tensor& output,
const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
const double scale, const bool causal,
const std::optional<torch::Tensor>& alibi_slopes,
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, const double softcap,
const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux);
// Note: just for avoiding importing errors.
// Stub used as the implementation of ops that are registered by name only;
// calling it always fails the TORCH_CHECK with the message "Unimplemented".
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
torch::Tensor& output, const torch::Tensor& scales,
const std::optional<torch::Tensor>& zeros,
const std::optional<torch::Tensor>& g_idx,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
ops.def(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
......@@ -100,20 +122,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
// Activation ops
// Activation function used in SwiGLU.
......@@ -164,7 +172,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
defined(__powerpc64__)
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
// Helper function to release oneDNN handlers
ops.def("release_dnnl_matmul_handler(int handler) -> ()",
&release_dnnl_matmul_handler);
......@@ -181,6 +188,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int handler) -> ()");
ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
// Check if oneDNN was built with ACL backend
ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);
// Create oneDNN W8A8 handler
ops.def(
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
......@@ -197,15 +207,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()",
{stride_tag});
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()",
{stride_tag});
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
#endif
......@@ -254,37 +262,40 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
// CPU attention kernels
ops.def(
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
"int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
"query_start_loc, bool casual, int window_size, str isa_hint, bool "
"enable_kv_split) -> Tensor",
&get_scheduler_metadata);
ops.def(
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
"isa) -> ()",
&cpu_attn_reshape_and_cache);
ops.def(
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
&cpu_attention_with_kv_cache);
// placeholders
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
// WNA16
#if defined(__AVX512F__)
ops.def(
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
......
......@@ -45,21 +45,55 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
bitmask* src_mask = numa_get_membind();
int pid = getpid();
// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno));
std::set<int> node_ids;
for (const auto& cpu_id : omp_cpu_ids) {
int node_id = numa_node_of_cpu(cpu_id);
if (node_id != -1) {
node_ids.insert(node_id);
}
if (node_id != mem_node_id) {
TORCH_WARN("CPU ", cpu_id, " is on NUMA node ", node_id, ", but CPU ",
omp_cpu_ids.front(), " is on NUMA node ", mem_node_id,
". All CPUs should be on the same NUMA node for optimal "
"performance. Memory will be bound to NUMA node ",
mem_node_id, ".");
}
}
// Concatenate all node_ids into a single comma-separated string
if (!node_ids.empty()) {
std::string node_ids_str;
for (const int node_id : node_ids) {
if (!node_ids_str.empty()) {
node_ids_str += ",";
}
node_ids_str += std::to_string(node_id);
}
// restrict memory allocation node.
numa_set_membind(mask);
numa_set_strict(1);
bitmask* mask = numa_parse_nodestring(node_ids_str.c_str());
bitmask* src_mask = numa_get_membind();
int pid = getpid();
if (mask && src_mask) {
// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_WARN("numa_migrate_pages failed. errno: " +
std::to_string(errno));
}
// restrict memory allocation node.
numa_set_membind(mask);
numa_set_strict(1);
numa_free_nodemask(mask);
numa_free_nodemask(src_mask);
} else {
TORCH_WARN("numa_parse_nodestring or numa_get_membind failed. errno: " +
std::to_string(errno));
}
}
}
// OMP threads binding
......
#ifndef UTILS_HPP
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
namespace cpu_utils {
// Selector with two options, AMX and VEC.
// NOTE(review): presumably names the instruction-set path a kernel should
// use; the dispatch itself is not visible in this header - confirm.
enum class ISA { AMX, VEC };
// Maps a scalar element type T to the vec_op vector type used for it.
// The primary template yields `void`, i.e. no vector type is associated with
// T unless one of the specializations below matches.
template <typename T>
struct VecTypeTrait {
using vec_t = void;
};
// float -> vec_op::FP32Vec16
template <>
struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
// The BFloat16 mapping is compiled on every non-aarch64 target, and on
// aarch64 only when ARM_BF16_SUPPORT is defined.
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
// c10::BFloat16 -> vec_op::BF16Vec16
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
// c10::Half -> vec_op::FP16Vec16
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
// Atomic ticket dispenser: every acquire hands out the current value and
// advances it by one; reset puts it back to zero.
struct Counter {
  std::atomic<int64_t> counter;
  // 56 trailing bytes after the atomic.
  // NOTE(review): presumably rounds the struct up to one 64-byte cache line
  // so adjacent counters do not share a line - confirm.
  char _padding[56];

  Counter() : counter(0) {}

  // Restart numbering from zero.
  void reset_counter() { counter.store(0, std::memory_order_seq_cst); }

  // Return the value held before the increment.
  int64_t acquire_counter() {
    return counter.fetch_add(1, std::memory_order_seq_cst);
  }
};
// Returns the L2 cache budget in bytes: half of the L2 size reported by the
// OS. The value is computed once on first call and cached for the lifetime of
// the process.
//
// If the OS query fails or reports a non-positive size (sysconf may return -1
// or 0 when the cache size is unknown), half of 128KB is used instead so that
// callers never receive a zero or negative budget.
inline int64_t get_l2_size() {
  static const int64_t size = []() -> int64_t {
    // Assumed L2 size when the platform cannot report one.
    constexpr int64_t kFallbackL2Bytes = 128LL * 1024;
    int64_t l2_cache_size = 0;
#if defined(__APPLE__)
    // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
    size_t len = sizeof(l2_cache_size);
    if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) != 0) {
      l2_cache_size = 0;  // discard whatever a failed call may have written
    }
#else
    l2_cache_size = static_cast<int64_t>(sysconf(_SC_LEVEL2_CACHE_SIZE));
#endif
    if (l2_cache_size <= 0) {
      l2_cache_size = kFallbackL2Bytes;
    }
    return l2_cache_size >> 1;  // use 50% of L2 cache
  }();
  return size;
}
} // namespace cpu_utils
#endif
......@@ -12,6 +12,7 @@ using CubMaxOp = cub::Max;
#endif // CUB_VERSION
#else
#include <hipcub/hipcub.hpp>
using CubAddOp = cub::Sum;
using CubMaxOp = cub::Max;
namespace cub = hipcub;
using CubAddOp = hipcub::Sum;
using CubMaxOp = hipcub::Max;
#endif // USE_ROCM
......@@ -22,15 +22,10 @@ torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
auto strides = cpu_tensor.strides();
auto options = cpu_tensor.options().device(torch::kCUDA);
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto deleter = [](void*) {
// no-op, since the memory is owned by the original CPU tensor
};
// use default no-op deleter, since the memory is owned by the original CPU
// tensor
torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, deleter, options);
torch::from_blob(device_ptr, sizes, strides, options);
TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
......
......@@ -3,14 +3,58 @@
// need to be unsigned long long
#include <iostream>
#include "cumem_allocator_compat.h"
#ifndef USE_ROCM
static const char* PYARGS_PARSE = "KKKK";
#else
#include <cstdlib>
#include <cerrno>
#include <climits>
// Chunk size used for memory creation on ROCm. Defaults to 256MB and can be
// changed at runtime with the environment variable
// VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE, given as a whole number of megabytes
// (decimal, or hex with a 0x prefix - strtoull base 0 rules).
static const unsigned long long DEFAULT_MEMCREATE_CHUNK_SIZE =
    (256ULL * 1024ULL * 1024ULL);

// Returns the configured chunk size in bytes.
// The default is returned when the variable is unset, when strtoull consumes
// no characters or reports an error, when the parsed value is 0, or when
// converting megabytes to bytes would overflow unsigned long long.
static unsigned long long get_memcreate_chunk_size() {
  const unsigned long long bytes_per_mb = 1024ULL * 1024ULL;

  const char* raw = getenv("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE");
  if (raw != nullptr) {
    char* parse_end = nullptr;
    errno = 0;
    const unsigned long long megabytes = strtoull(raw, &parse_end, 0);

    const bool parsed = (parse_end != raw) && (errno == 0);
    // Upper bound keeps megabytes * bytes_per_mb representable.
    const bool in_range =
        (megabytes != 0) && (megabytes <= (ULLONG_MAX / bytes_per_mb));
    if (parsed && in_range) {
      return megabytes * bytes_per_mb;
    }
  }
  return DEFAULT_MEMCREATE_CHUNK_SIZE;
}
// Smaller of two unsigned 64-bit values.
static inline unsigned long long my_min(unsigned long long a,
                                        unsigned long long b) {
  if (b < a) {
    return b;
  }
  return a;
}
static const char* PYARGS_PARSE = "KKKO";
#endif
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
char error_msg[10240]; // 10KB buffer to store error messages
CUresult no_error = CUresult(0);
......@@ -49,7 +93,12 @@ void ensure_context(unsigned long long device) {
}
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle) {
#else
CUmemGenericAllocationHandle** p_memHandle,
unsigned long long* chunk_sizes, size_t num_chunks) {
#endif
ensure_context(device);
// Define memory allocation properties
CUmemAllocationProp prop = {};
......@@ -58,6 +107,7 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
if (error_code != 0) {
......@@ -67,6 +117,39 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
if (error_code != 0) {
return;
}
#else
for (auto i = 0; i < num_chunks; ++i) {
CUDA_CHECK(cuMemCreate(p_memHandle[i], chunk_sizes[i], &prop, 0));
if (error_code != 0) {
// Clean up previously created handles
for (auto j = 0; j < i; ++j) {
cuMemRelease(*(p_memHandle[j]));
}
return;
}
}
unsigned long long allocated_size = 0;
for (auto i = 0; i < num_chunks; ++i) {
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
CUDA_CHECK(cuMemMap(map_addr, chunk_sizes[i], 0, *(p_memHandle[i]), 0));
if (error_code != 0) {
// unmap previously mapped chunks
unsigned long long unmapped_size = 0;
for (auto j = 0; j < i; ++j) {
void* unmap_addr = (void*)((uintptr_t)d_mem + unmapped_size);
cuMemUnmap(unmap_addr, chunk_sizes[j]);
unmapped_size += chunk_sizes[j];
}
// release all created handles
for (auto j = 0; j < num_chunks; ++j) {
cuMemRelease(*(p_memHandle[j]));
}
return;
}
allocated_size += chunk_sizes[i];
}
#endif
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = device;
......@@ -82,10 +165,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
void unmap_and_release(unsigned long long device, ssize_t size,
CUdeviceptr d_mem,
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle) {
#else
CUmemGenericAllocationHandle** p_memHandle,
unsigned long long* chunk_sizes, size_t num_chunks) {
#endif
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
#ifndef USE_ROCM
CUDA_CHECK(cuMemUnmap(d_mem, size));
if (error_code != 0) {
return;
......@@ -94,6 +183,30 @@ void unmap_and_release(unsigned long long device, ssize_t size,
if (error_code != 0) {
return;
}
#else
unsigned long long allocated_size = 0;
CUresult first_error = no_error;
for (auto i = 0; i < num_chunks; ++i) {
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
CUresult status = cuMemUnmap(map_addr, chunk_sizes[i]);
if (status != no_error && first_error == no_error) {
first_error = status;
}
allocated_size += chunk_sizes[i];
}
for (auto i = 0; i < num_chunks; ++i) {
CUresult status = cuMemRelease(*(p_memHandle[i]));
if (status != no_error && first_error == no_error) {
first_error = status;
}
}
if (first_error != no_error) {
CUDA_CHECK(first_error);
}
#endif
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
......@@ -120,6 +233,36 @@ PyObject* create_tuple_from_c_integers(unsigned long long a,
return tuple; // Return the created tuple
}
// Builds the Python argument tuple (a, b, c, chunks), where `chunks` is a
// list of (handle_address, chunk_size) integer pairs, one per entry of
// vec / chunk_sizes.
//
// a, b, c      integers stored as the first three tuple items.
// vec          array of num_chunks handle pointers; each pointer value is
//              stored as an integer.
// chunk_sizes  array of num_chunks sizes, stored alongside each pointer.
// num_chunks   number of entries in vec and chunk_sizes.
//
// Returns a new reference, or NULL with a Python exception set if any
// allocation fails. On failure every object created so far is released, so
// nothing leaks and no container is left holding NULL items.
PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
                                    unsigned long long c,
                                    CUmemGenericAllocationHandle** vec,
                                    unsigned long long* chunk_sizes,
                                    size_t num_chunks) {
  PyObject* tuple = PyTuple_New(4);
  if (!tuple) {
    return NULL;
  }

  PyObject* list = PyList_New(static_cast<Py_ssize_t>(num_chunks));
  if (!list) {
    Py_DECREF(tuple);
    return NULL;
  }

  for (size_t i = 0; i < num_chunks; ++i) {
    PyObject* addr = PyLong_FromUnsignedLongLong((unsigned long long)(vec[i]));
    PyObject* size =
        PyLong_FromUnsignedLongLong((unsigned long long)(chunk_sizes[i]));
    // Only allocate the pair once both members exist.
    PyObject* addr_size_pair = (addr && size) ? PyTuple_New(2) : NULL;
    if (!addr_size_pair) {
      Py_XDECREF(addr);
      Py_XDECREF(size);
      // Deallocating a partially filled list/tuple is safe: unset slots are
      // NULL and are skipped.
      Py_DECREF(list);
      Py_DECREF(tuple);
      return NULL;
    }
    // The SetItem calls below steal the references they are given.
    PyTuple_SetItem(addr_size_pair, 0, addr);
    PyTuple_SetItem(addr_size_pair, 1, size);
    PyList_SetItem(list, static_cast<Py_ssize_t>(i), addr_size_pair);
  }

  PyObject* py_a = PyLong_FromUnsignedLongLong(a);
  PyObject* py_b = PyLong_FromUnsignedLongLong(b);
  PyObject* py_c = PyLong_FromUnsignedLongLong(c);
  if (!py_a || !py_b || !py_c) {
    Py_XDECREF(py_a);
    Py_XDECREF(py_b);
    Py_XDECREF(py_c);
    Py_DECREF(list);
    Py_DECREF(tuple);
    return NULL;
  }

  PyTuple_SetItem(tuple, 0, py_a);
  PyTuple_SetItem(tuple, 1, py_b);
  PyTuple_SetItem(tuple, 2, py_c);
  PyTuple_SetItem(tuple, 3, list);
  return tuple;
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
......@@ -147,14 +290,55 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
CUdeviceptr d_mem;
#ifndef USE_ROCM
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
if (error_code != 0) {
return nullptr;
}
#else
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, granularity, 0, 0));
if (error_code != 0) {
return nullptr;
}
#endif
#ifndef USE_ROCM
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
#else
// Make sure chunk size is aligned with hardware granularity. The base
// chunk size can be configured via environment variable
// ``VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE``; otherwise
// DEFAULT_MEMCREATE_CHUNK_SIZE is used.
size_t base_chunk = (size_t)get_memcreate_chunk_size();
size_t aligned_chunk_size =
((base_chunk + granularity - 1) / granularity) * granularity;
size_t num_chunks =
(alignedSize + aligned_chunk_size - 1) / aligned_chunk_size;
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
for (auto i = 0; i < num_chunks; ++i) {
p_memHandle[i] = (CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
if (p_memHandle[i] == nullptr) {
std::cerr << "ERROR: malloc failed for p_memHandle[" << i << "].\n";
for (auto j = 0; j < i; ++j) {
free(p_memHandle[j]);
}
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
chunk_sizes[i] = (unsigned long long)my_min(
(unsigned long long)(alignedSize - i * aligned_chunk_size),
(unsigned long long)aligned_chunk_size);
}
#endif
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
......@@ -164,9 +348,15 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
#ifndef USE_ROCM
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
#else
PyObject* arg_tuple = create_tuple_from_c_mixed(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, p_memHandle, chunk_sizes, num_chunks);
#endif
// Call g_python_malloc_callback
PyObject* py_result =
......@@ -182,7 +372,27 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
PyGILState_Release(gstate);
// do the final mapping
#ifndef USE_ROCM
create_and_map(device, alignedSize, d_mem, p_memHandle);
#else
create_and_map(device, alignedSize, d_mem, p_memHandle, chunk_sizes,
num_chunks);
free(chunk_sizes);
#endif
if (error_code != 0) {
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, alignedSize));
#ifndef USE_ROCM
free(p_memHandle);
#else
for (size_t i = 0; i < num_chunks; ++i) {
free(p_memHandle[i]);
}
free(p_memHandle);
#endif
return nullptr;
}
return (void*)d_mem;
}
......@@ -206,36 +416,96 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
Py_XDECREF(py_result);
Py_XDECREF(py_ptr);
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
if (!PyArg_ParseTuple(py_result, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
Py_XDECREF(py_result);
Py_XDECREF(py_ptr);
return;
}
PyGILState_Release(gstate);
// For ROCm, copy the Python list of (addr,size) pairs into C arrays while
// holding the GIL. Then release the GIL and call the unmap/release helper
// using the copied arrays. This avoids calling PyList_* APIs without the
// GIL (which is undefined behavior and can crash when called from other
// threads).
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
#ifdef USE_ROCM
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
std::cerr << "ERROR: malloc failed for p_memHandle in my_free."
<< std::endl;
return;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
std::cerr << "ERROR: malloc failed for chunk_sizes in my_free."
<< std::endl;
return;
}
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
}
// recv_size == size
// recv_device == device
// Drop temporary Python refs, then release the GIL before calling into
// non-Python APIs.
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
// Free memory
unmap_and_release(device, size, d_mem, p_memHandle, chunk_sizes, num_chunks);
#else
// Non-ROCm path: simple integer handle already extracted; drop temporary
// Python refs while still holding the GIL, then release it.
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
#endif
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, size));
if (error_code != 0) {
return;
#ifndef USE_ROCM
free(p_memHandle);
#else
for (auto i = 0; i < num_chunks; ++i) {
free(p_memHandle[i]);
}
free(p_memHandle);
free(chunk_sizes);
#endif
}
// ---------------------------------------------------------------------------
......@@ -271,19 +541,87 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
#else
if (!PyList_Check(recv_p_memHandle)) {
PyErr_SetString(PyExc_TypeError,
"Expected a list for the 4th argument on ROCm");
return nullptr;
}
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
if (num_chunks < 0) {
return nullptr; // PyList_Size sets an exception on error.
}
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
return nullptr;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
return nullptr;
}
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
if (item == nullptr || !PyTuple_Check(item) || PyTuple_Size(item) != 2) {
free(p_memHandle);
free(chunk_sizes);
PyErr_SetString(
PyExc_TypeError,
"List items must be tuples of size 2 (handle_addr, size)");
return nullptr;
}
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
if (addr_py == nullptr || size_py == nullptr) {
free(p_memHandle);
free(chunk_sizes);
return nullptr; // PyTuple_GetItem sets an exception
}
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
if (PyErr_Occurred()) {
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
if (PyErr_Occurred()) {
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
}
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
num_chunks);
free(p_memHandle);
free(chunk_sizes);
#endif
if (error_code != 0) {
error_code = no_error;
......@@ -301,19 +639,56 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
#else
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
return nullptr;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
return nullptr;
}
for (auto i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
chunk_sizes[i] = PyLong_AsUnsignedLongLong(size_py);
}
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
num_chunks);
free(p_memHandle);
free(chunk_sizes);
#endif
if (error_code != 0) {
error_code = no_error;
......
#pragma once
#ifdef USE_ROCM
////////////////////////////////////////
// For compatibility with CUDA and ROCm
////////////////////////////////////////
#include <hip/hip_runtime_api.h>
extern "C" {
#ifndef CUDA_SUCCESS
#define CUDA_SUCCESS hipSuccess
#endif // CUDA_SUCCESS
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
typedef unsigned long long CUdevice;
typedef hipDeviceptr_t CUdeviceptr;
typedef hipError_t CUresult;
typedef hipCtx_t CUcontext;
typedef hipStream_t CUstream;
typedef hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle;
typedef hipMemAllocationGranularity_flags CUmemAllocationGranularity_flags;
typedef hipMemAllocationProp CUmemAllocationProp;
typedef hipMemAccessDesc CUmemAccessDesc;
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
#define CU_MEM_ALLOCATION_COMP_NONE 0x0
// Error Handling
// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
*pStr = hipGetErrorString(hipError);
return CUDA_SUCCESS;
}
// Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
CUresult cuCtxGetCurrent(CUcontext* ctx) {
// This API is deprecated on the AMD platform, only for equivalent cuCtx
// driver API on the NVIDIA platform.
return hipCtxGetCurrent(ctx);
}
CUresult cuCtxSetCurrent(CUcontext ctx) {
// This API is deprecated on the AMD platform, only for equivalent cuCtx
// driver API on the NVIDIA platform.
return hipCtxSetCurrent(ctx);
}
// Primary Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
return hipDevicePrimaryCtxRetain(ctx, dev);
}
// Virtual Memory Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
// CUDA-named shim: forwards (ptr, size) to hipMemAddressFree and returns its
// status.
CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
return hipMemAddressFree(ptr, size);
}
// CUDA-named shim: forwards all five arguments, in order, to
// hipMemAddressReserve and returns its status.
CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size, size_t alignment,
CUdeviceptr addr, unsigned long long flags) {
return hipMemAddressReserve(ptr, size, alignment, addr, flags);
}
// CUDA-named shim: forwards (handle, size, prop, flags) to hipMemCreate and
// returns its status.
CUresult cuMemCreate(CUmemGenericAllocationHandle* handle, size_t size,
const CUmemAllocationProp* prop,
unsigned long long flags) {
return hipMemCreate(handle, size, prop, flags);
}
// CUDA-named shim: forwards (granularity, prop, option) to
// hipMemGetAllocationGranularity and returns its status.
CUresult cuMemGetAllocationGranularity(
size_t* granularity, const CUmemAllocationProp* prop,
CUmemAllocationGranularity_flags option) {
return hipMemGetAllocationGranularity(granularity, prop, option);
}
// CUDA-named shim: forwards (dptr, size, offset, handle, flags) to hipMemMap
// and returns its status.
CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
CUmemGenericAllocationHandle handle,
unsigned long long flags) {
return hipMemMap(dptr, size, offset, handle, flags);
}
// CUDA-named shim: forwards `handle` to hipMemRelease and returns its status.
CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
return hipMemRelease(handle);
}
// CUDA-named shim: forwards (ptr, size, desc, count) to hipMemSetAccess and
// returns its status.
CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
const CUmemAccessDesc* desc, size_t count) {
return hipMemSetAccess(ptr, size, desc, count);
}
// CUDA-named shim: forwards (ptr, size) to hipMemUnmap and returns its status.
CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
return hipMemUnmap(ptr, size);
}
} // extern "C"
#else
////////////////////////////////////////
// Import CUDA headers for NVIDIA GPUs
////////////////////////////////////////
#include <cuda_runtime_api.h>
#include <cuda.h>
#endif
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from typing import Union
from cutlass_library import *
......@@ -22,31 +21,31 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative = enum_auto()
VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeNames: dict[VLLMDataType | DataType, str] = {
**DataTypeNames, # type: ignore
**{
VLLMDataType.u4b8: "u4b8",
VLLMDataType.u8b128: "u8b128",
}
},
}
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeTag: dict[VLLMDataType | DataType, str] = {
**DataTypeTag, # type: ignore
**{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}
},
}
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
VLLMDataTypeSize: dict[VLLMDataType | DataType, int] = {
**DataTypeSize, # type: ignore
**{
VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8,
}
},
}
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeVLLMScalarTypeTag: dict[VLLMDataType | DataType, str] = {
VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4",
......@@ -57,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.bf16: "vllm::kBfloat16",
}
VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeTorchDataTypeTag: dict[VLLMDataType | DataType, str] = {
DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
......@@ -67,15 +66,11 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.f32: "at::ScalarType::Float",
}
VLLMKernelScheduleTag: dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecialized:
"cutlass::gemm::KernelTmaWarpSpecialized",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}
}
VLLMKernelScheduleTag: dict[MixedInputKernelScheduleType | KernelScheduleType, str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", # noqa: E501
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", # noqa: E501
},
}
......@@ -88,3 +88,53 @@
// Dispatches on TYPE through AT_DISPATCH_SWITCH, using the case list that
// VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES builds from the trailing
// arguments. NAME is handed to AT_DISPATCH_SWITCH unchanged.
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
// One branch of VLLM_DISPATCH_VEC_SIZE: under the switch label LABEL, binds
// the compile-time constant `vec_size` to N, invokes the callable given as
// the trailing argument(s) and leaves the switch.
#define VLLM_DISPATCH_VEC_SIZE_BRANCH(LABEL, N, ...) \
  LABEL: {                                           \
    constexpr int vec_size = N;                      \
    __VA_ARGS__();                                   \
    break;                                           \
  }

// Turns the runtime value VEC_SIZE into a `constexpr int vec_size` that is
// visible to the callable passed as the trailing argument(s). The values 16,
// 8, 4 and 2 are dispatched as themselves; every other value runs the
// callable with vec_size == 1.
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...)                 \
  switch (VEC_SIZE) {                                         \
    VLLM_DISPATCH_VEC_SIZE_BRANCH(case 16, 16, __VA_ARGS__)   \
    VLLM_DISPATCH_VEC_SIZE_BRANCH(case 8, 8, __VA_ARGS__)     \
    VLLM_DISPATCH_VEC_SIZE_BRANCH(case 4, 4, __VA_ARGS__)     \
    VLLM_DISPATCH_VEC_SIZE_BRANCH(case 2, 2, __VA_ARGS__)     \
    VLLM_DISPATCH_VEC_SIZE_BRANCH(default, 1, __VA_ARGS__)    \
  }
// One branch of VLLM_DISPATCH_RANK234: for `case RANK`, binds the
// compile-time constant `tensor_rank` to RANK, invokes the callable given as
// the trailing argument(s) and leaves the switch.
#define VLLM_DISPATCH_RANK234_CASE(RANK, ...) \
  case RANK: {                                \
    constexpr int tensor_rank = RANK;         \
    __VA_ARGS__();                            \
    break;                                    \
  }

// Turns the runtime value NUM_DIMS into a `constexpr int tensor_rank` that is
// visible to the callable passed as the trailing argument(s). Only 2, 3 and 4
// are accepted; any other value reaches the TORCH_CHECK(false, ...) in the
// default branch and the callable is not invoked.
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...)                                   \
  switch (NUM_DIMS) {                                                          \
    VLLM_DISPATCH_RANK234_CASE(2, __VA_ARGS__)                                 \
    VLLM_DISPATCH_RANK234_CASE(3, __VA_ARGS__)                                 \
    VLLM_DISPATCH_RANK234_CASE(4, __VA_ARGS__)                                 \
    default:                                                                   \
      TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
  }
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cuda_runtime.h>
#include <type_traits>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
// Tensor-argument validation helpers. Each expands to a TORCH_CHECK whose
// message starts with the spelling of the checked expression (#x).

// Requires that tensor `x` has scalar type `st`.
#define CHECK_TYPE(x, st)                                                 \
  TORCH_CHECK((x).scalar_type() == (st), #x " dtype is ",                 \
              (x).scalar_type(), ", while ", (st), " is expected")
// Requires that tensor `x` reports is_cuda().
#define CHECK_TH_CUDA(x) \
  TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
// Requires that tensor `x` reports is_contiguous().
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs CHECK_TH_CUDA followed by CHECK_CONTIGUOUS on `x`. The do/while(0)
// wrapper makes the expansion a single statement, so both checks stay guarded
// when the macro is used as the body of an unbraced if/else.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_TH_CUDA(x);    \
    CHECK_CONTIGUOUS(x); \
  } while (0)
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The implementation
// below is copied from the one shipped in ROCm 7.0.
// The statement order is significant: a release fence, then the wave barrier,
// then an acquire fence, each fence using the "wavefront" scope.
__device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
#endif
#else
#define FINAL_MASK 0xffffffff
#endif
namespace tensorrt_llm::common {
// Maps an element type T and a count `num` to a vector type via the nested
// `type` alias. The primary template is only declared; just the <uint, N>
// specializations below are defined, so any other combination is a compile
// error when `::type` is requested.
template <typename T, int num>
struct packed_as;
// Specialization for packed_as used in this kernel.
// 1 x uint -> uint
template <>
struct packed_as<uint, 1> {
using type = uint;
};
// 2 x uint -> uint2
template <>
struct packed_as<uint, 2> {
using type = uint2;
};
// 4 x uint -> uint4
template <>
struct packed_as<uint, 4> {
using type = uint4;
};
// Sums `val` across a group of 32 lanes with XOR shuffles using lane masks
// 16, 8, 4, 2 and 1; every participating lane returns the accumulated value.
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int laneMask = 16; laneMask >= 1; laneMask /= 2) {
    val += __shfl_xor_sync(FINAL_MASK, val, laneMask, 32);
  }
  return val;
}
// Integer ceiling division: evaluates (m + n - 1) / n, i.e. the number of
// chunks of size n needed to cover m items.
template <typename T>
inline __device__ __host__ T divUp(T m, T n) {
  const auto padded = m + n - 1;
  return padded / n;
}
} // namespace tensorrt_llm::common
namespace tensorrt_llm::kernels {
// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation,
// with added support for passing the cos_sin_cache as an input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
// Perform per-head QK Norm and RoPE in a single kernel.
// Work split: each group of 32 consecutive threads (one "warp" below) handles
// one (token, head) pair taken from the Q and K segments of the packed QKV
// row and rewrites that head in place. V heads are never read or written;
// num_heads_v only contributes to the per-token row stride.
// scalar_t_in: data type of QKV and RMSNorm weights
// scalar_t_cache: data type of cos/sin cache
// head_dim: the dimension of each head
// interleave: interleave=!is_neox.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
bool interleave>
__global__ void fusedQKNormRopeKernel(
void* qkv_void, // Combined QKV tensor
int const num_heads_q, // Number of query heads
int const num_heads_k, // Number of key heads
int const num_heads_v, // Number of value heads
float const eps, // Epsilon for RMS normalization
void const* q_weight_void, // RMSNorm weights for query
void const* k_weight_void, // RMSNorm weights for key
void const* cos_sin_cache_void, // Pre-computed cos/sin cache
int64_t const* position_ids, // Position IDs for RoPE
int const num_tokens // Number of tokens
) {
// For non-ROCm builds where __CUDA_ARCH__ is undefined or below 800, the
// BFloat16 instantiations compile to an immediate return; every other
// instantiation takes the else-branch that holds the real kernel body.
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
std::is_same_v<scalar_t_cache, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
static_assert(Converter::exists,
"Input QKV data type is not supported for this CUDA "
"architecture or toolkit version.");
using T_in = typename Converter::hip_type;
using T2_in = typename Converter::packed_hip_type;
using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
static_assert(CacheConverter::exists,
"Cache data type is not supported for this CUDA architecture "
"or toolkit version.");
using T_cache = typename CacheConverter::hip_type;
// Reinterpret the untyped arguments with the converter's device types.
T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
T_cache const* cos_sin_cache =
reinterpret_cast<T_cache const*>(cos_sin_cache_void);
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
// Calculate global warp index to determine which head/token this warp
// processes
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
// Total number of attention heads (Q and K)
int const total_qk_heads = num_heads_q + num_heads_k;
// Determine which token and head type (Q or K) this warp processes
int const tokenIdx = globalWarpIdx / total_qk_heads;
int const localHeadIdx = globalWarpIdx % total_qk_heads;
// Skip if this warp is assigned beyond the number of tokens. tokenIdx is
// derived from warpId only, so all 32 lanes of a warp take this exit together.
if (tokenIdx >= num_tokens) return;
bool const isQ = localHeadIdx < num_heads_q;
int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;
// Heads per token row in the packed QKV layout (Q + K + V).
int const num_heads = num_heads_q + num_heads_k + num_heads_v;
static_assert(head_dim % (32 * 2) == 0,
"head_dim must be divisible by 64 (each warp processes one "
"head, and each thread gets even number of "
"elements)");
// Each lane owns numElemsPerThread consecutive elements of the head.
constexpr int numElemsPerThread = head_dim / 32;
float elements[numElemsPerThread];
// NOTE(review): the per-thread byte count is computed with
// sizeof(__nv_bfloat16) rather than sizeof(T_in) — presumably every supported
// input type has the same size; confirm before adding wider input types.
constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
static_assert(elemSizeBytes % 4 == 0,
"numSizeBytes must be a multiple of 4");
constexpr int vecSize =
elemSizeBytes /
4; // Use packed_as<uint, vecSize> to perform loading/saving.
using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;
// NOTE(review): element offsets are computed in 32-bit int; confirm callers
// bound num_tokens * num_heads * head_dim so these products cannot overflow.
int offsetWarp; // Offset for the warp
if (isQ) {
// Q segment: token offset + head offset within Q segment
offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim;
} else {
// K segment: token offset + entire Q segment + head offset within K
// segment
offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim +
headIdx * head_dim;
}
int offsetThread = offsetWarp + laneId * numElemsPerThread;
// Sum of squares for RMSNorm
float sumOfSquares = 0.0f;
// Load: a single vec_T read brings in this lane's elements, which are then
// unpacked pairwise into floats while accumulating the sum of squares.
{
vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Interpret the generic vector chunk as the specific packed type
T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
// Convert to float2 for computation
float2 vals = Converter::convert(packed_val);
sumOfSquares += vals.x * vals.x;
sumOfSquares += vals.y * vals.y;
elements[2 * i] = vals.x;
elements[2 * i + 1] = vals.y;
}
}
// Reduce sum across warp using the utility function
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
// Compute RMS normalization factor: 1 / sqrt(mean(x^2) + eps)
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
// Normalize elements and scale by the Q or K weight for the same dimension
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? Converter::convert(q_weight[dim])
: Converter::convert(k_weight[dim]);
elements[i] *= rms_rcp * weight;
}
// Apply RoPE to normalized elements
float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
int64_t pos_id = position_ids[tokenIdx];
// Calculate cache pointer for this position - similar to
// pos_encoding_kernels.cu
// The cache row for a position is indexed as head_dim values: the first
// head_dim / 2 are read as cosines, the next head_dim / 2 as sines.
T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
int const embed_dim = head_dim / 2;
T_cache const* cos_ptr = cache_ptr;
T_cache const* sin_ptr = cache_ptr + embed_dim;
if constexpr (interleave) {
// Perform interleaving. Use pre-computed cos/sin values.
// Adjacent elements (2i, 2i + 1) held by the same lane form a rotation pair,
// so no cross-lane exchange is needed in this branch.
#pragma unroll
for (int i = 0; i < numElemsPerThread / 2; ++i) {
int const idx0 = 2 * i;
int const idx1 = 2 * i + 1;
float const val0 = elements[idx0];
float const val1 = elements[idx1];
int const dim_idx = laneId * numElemsPerThread + idx0;
int const half_dim = dim_idx / 2;
float const cos_val =
CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
float const sin_val =
CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
elements[idx0] = val0 * cos_val - val1 * sin_val;
elements[idx1] = val0 * sin_val + val1 * cos_val;
}
} else {
// Before exchanging data within the warp, we need to sync.
__syncwarp();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
// Each lane reads the matching element of lane (laneId ^ 16); lanes 0-15
// negate the value they receive before it is combined with sin.
#pragma unroll
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
if (laneId < 16) {
elements2[i] = -elements2[i];
}
int dim_idx = laneId * numElemsPerThread + i;
// Fold the index so that both halves of the head address the same
// cos/sin entry.
dim_idx = (dim_idx * 2) % head_dim;
int half_dim = dim_idx / 2;
// Use pre-computed cos/sin from cache
float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));
elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
}
// __shfl_xor_sync does not provide memfence. Need to sync again.
__syncwarp();
}
// Store: repack the floats pairwise and write them back over the input with
// a single vec_T store.
{
vec_T vec;
constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
for (int i = 0; i < num_packed_elems; i++) {
// Convert from float2 back to the specific packed type
T2_in packed_val = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
// Place it into the generic vector
*(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
}
*reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
}
// Closes the else-branch opened by the architecture guard at the top.
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
// Converts the runtime flag `interleave` into a `const bool` named by the
// INTERLEAVE argument and pastes the trailing code once into each branch of
// an if/else: with the constant set to true in one branch and to false in the
// other.
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
int const num_heads_q, int const num_heads_k,
int const num_heads_v, int const head_dim,
float const eps, void const* q_weight,
void const* k_weight, void const* cos_sin_cache,
bool const interleave, int64_t const* position_ids,
cudaStream_t stream) {
constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32;
int const totalQKHeads = num_heads_q + num_heads_k;
int const totalWarps = num_tokens * totalQKHeads;
int const gridSize = common::divUp(totalWarps, warpsPerBlock);
dim3 gridDim(gridSize);
dim3 blockDim(blockSize);
switch (head_dim) {
case 64:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
});
break;
case 128:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
});
break;
case 256:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
<<<gridDim, blockDim, 0, stream>>>(
qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
k_weight, cos_sin_cache, position_ids, num_tokens);
});
break;
default:
TORCH_CHECK(false,
"Unsupported head dimension for fusedQKNormRope: ", head_dim);
}
}
} // namespace tensorrt_llm::kernels
// Applies per-head RMSNorm followed by RoPE to the Q and K heads of a packed
// QKV tensor, in place. V heads are left untouched.
//
// All tensors must be contiguous CUDA tensors. `qkv`, `q_weight` and
// `k_weight` must share one dtype; `position_ids` must be int64. Shape or
// dtype violations, an unsupported head_dim, or a `qkv` tensor too large for
// the kernel's 32-bit element offsets raise via TORCH_CHECK. An input with
// zero tokens is accepted and is a no-op.
void fused_qk_norm_rope(
    torch::Tensor& qkv,   // Combined QKV tensor [num_tokens,
                          // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int64_t num_heads_q,  // Number of query heads
    int64_t num_heads_k,  // Number of key heads
    int64_t num_heads_v,  // Number of value heads
    int64_t head_dim,     // Dimension per head
    double eps,           // Epsilon for RMS normalization
    torch::Tensor& q_weight,       // RMSNorm weights for query [head_dim]
    torch::Tensor& k_weight,       // RMSNorm weights for key [head_dim]
    torch::Tensor& cos_sin_cache,  // Cos/sin cache [max_position, head_dim]
    bool is_neox,                  // Whether RoPE is applied in Neox style
    torch::Tensor& position_ids    // Position IDs for RoPE [num_tokens]
) {
  // Input validation
  CHECK_INPUT(qkv);
  CHECK_INPUT(position_ids);
  CHECK_INPUT(q_weight);
  CHECK_INPUT(k_weight);
  CHECK_INPUT(cos_sin_cache);
  CHECK_TYPE(position_ids, torch::kInt64);

  TORCH_CHECK(qkv.dim() == 2,
              "QKV tensor must be 2D: [num_tokens, "
              "(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
  TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
  TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
  TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
  TORCH_CHECK(cos_sin_cache.dim() == 2,
              "Cos/sin cache must be 2D: [max_position, head_dim]");
  TORCH_CHECK(q_weight.size(0) == head_dim,
              "Query weights size must match head dimension");
  TORCH_CHECK(k_weight.size(0) == head_dim,
              "Key weights size must match head dimension");
  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
              "Cos/sin cache dimension must match head_dim");
  TORCH_CHECK(qkv.scalar_type() == q_weight.scalar_type() &&
                  qkv.scalar_type() == k_weight.scalar_type(),
              "qkv, q_weight and k_weight must have the same dtype");

  int64_t num_tokens = qkv.size(0);
  TORCH_CHECK(position_ids.size(0) == num_tokens,
              "Number of tokens in position_ids must match QKV");

  int64_t total_heads = num_heads_q + num_heads_k + num_heads_v;
  TORCH_CHECK(
      qkv.size(1) == total_heads * head_dim,
      "QKV tensor size must match total number of heads and head dimension");

  // The kernel indexes qkv with 32-bit int element offsets, so reject tensors
  // whose element count does not fit instead of silently wrapping around.
  constexpr int64_t kMaxInt32 = (int64_t{1} << 31) - 1;
  TORCH_CHECK(qkv.numel() <= kMaxInt32,
              "QKV tensor has too many elements for fused_qk_norm_rope: ",
              qkv.numel());

  // Nothing to do for an empty batch; launching a zero-sized grid would be an
  // invalid launch configuration.
  if (num_tokens == 0) {
    return;
  }

  // Make qkv's device current for the launch so that it matches the stream
  // obtained below even when the caller's current device is a different one.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(qkv));
  auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());

  // Outer dispatch selects the QKV/weight dtype, inner dispatch the cos/sin
  // cache dtype; `scalar_t` is rebound by each dispatch macro, hence the
  // aliases.
  VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
    using qkv_scalar_t = scalar_t;
    VLLM_DISPATCH_FLOATING_TYPES(
        cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
          using cache_scalar_t = scalar_t;
          // The launcher's `interleave` flag is the negation of is_neox.
          tensorrt_llm::kernels::launchFusedQKNormRope<qkv_scalar_t,
                                                       cache_scalar_t>(
              qkv.data_ptr(), static_cast<int>(num_tokens),
              static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
              static_cast<int>(num_heads_v), static_cast<int>(head_dim),
              static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
              cos_sin_cache.data_ptr(), !is_neox,
              reinterpret_cast<int64_t const*>(position_ids.data_ptr()),
              stream);
        });
  });
}
\ No newline at end of file
......@@ -8,11 +8,37 @@
#define VLLM_LAUNCH_BLOCKS_CAP 4
#endif
// compile-time estimate of max threads per SM for launch bounds.
// Compile-time estimate of max threads per SM for launch bounds.
// Families: 1024, 1536, 2048 threads/SM.
#ifndef VLLM_MAX_THREADS_PER_SM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
#define VLLM_MAX_THREADS_PER_SM 1536
#ifdef __CUDA_ARCH__
/* 1024 thr/SM: Turing (sm_75) */
#if (__CUDA_ARCH__ == 750)
#define VLLM_MAX_THREADS_PER_SM 1024
/* 1536 thr/SM: Ampere GA10x (sm_86/87), Ada (sm_89),
GB20x consumer (sm_120/121), Thor (sm_101 or sm_110) */
#elif (__CUDA_ARCH__ == 860) || (__CUDA_ARCH__ == 870) || \
(__CUDA_ARCH__ == 890) || (__CUDA_ARCH__ == 1010) || \
(__CUDA_ARCH__ == 1100) || (__CUDA_ARCH__ == 1200) || \
(__CUDA_ARCH__ == 1210)
#define VLLM_MAX_THREADS_PER_SM 1536
/* 2048 thr/SM: Volta (sm_70/72), Ampere GA100 (sm_80),
Hopper (sm_90), Blackwell (sm_100/103) */
#elif (__CUDA_ARCH__ == 700) || (__CUDA_ARCH__ == 720) || \
(__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900) || \
(__CUDA_ARCH__ == 1000) || (__CUDA_ARCH__ == 1030)
#define VLLM_MAX_THREADS_PER_SM 2048
/* Fallback: use 2048 for unknown future CCs */
#else
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#else
/* Host pass (no __CUDA_ARCH__): neutral default */
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#endif
......
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
......@@ -8,20 +10,52 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
template <typename scalar_t, int VEC_SIZE, int NUM_DIMS>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride_d2, // input.stride(-2)
const int64_t input_stride_d3, // input.stride(-3)
const int64_t input_stride_d4, // input.stride(-4)
const int64_t input_shape_d2, // input.size(-2)
const int64_t input_shape_d3, // input.size(-3)
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
const scalar_t* input_row;
if constexpr (NUM_DIMS == 2) {
// 2D for layernorm normal case [batch_size, hidden]
input_row = input + blockIdx.x * input_stride_d2;
} else if constexpr (NUM_DIMS == 3) {
// 3D for q/k norm [batch_size, num_heads, head_size]
int batch_idx = blockIdx.x / input_shape_d2;
int head_idx = blockIdx.x % input_shape_d2;
input_row =
input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
} else if constexpr (NUM_DIMS == 4) {
// 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
int batch_idx = blockIdx.x / (input_shape_d3 * input_shape_d2);
int remaining = blockIdx.x % (input_shape_d3 * input_shape_d2);
int seq_idx = remaining / input_shape_d2;
int head_idx = remaining % input_shape_d2;
input_row = input + batch_idx * input_stride_d4 +
seq_idx * input_stride_d3 + head_idx * input_stride_d2;
}
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
......@@ -32,10 +66,20 @@ __global__ void rms_norm_kernel(
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * input_stride + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
scalar_t* out_row = out + blockIdx.x * hidden_size;
auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) {
vec_n_t<scalar_t, VEC_SIZE> dst;
vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
}
v_out[i] = dst;
}
}
......@@ -135,211 +179,6 @@ fused_add_rms_norm_kernel(
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
using Base = _f16Vec<scalar_t, width>;
using Converter = typename Base::Converter;
using T1 = typename Base::T1;
using T2 = typename Base::T2;
using Base::data;
// Returns the tuple (sum x^2, sum x^4, sum x^6) over the `width` stored
// elements. Elements are converted to float two at a time, so the loop
// advances in steps of 2 and reads data[i] and data[i + 1].
__device__ auto sum_pows() const {
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x4 = x2 * x2;
float x6 = x4 * x2;
float y2 = z.y * z.y;
float y4 = y2 * y2;
float y6 = y4 * y2;
s2 += x2 + y2;
s4 += x4 + y4;
s6 += x6 + y6;
}
return std::make_tuple(s2, s4, s6);
}
// Overwrites every stored element x with
//   w2_inv_std * x + w1_inv_std2 * x^2 + w0_inv_std3 * x^3 + bias,
// computed in float and converted back to the storage type pairwise.
__device__ void poly_norm_inplace(const float w2_inv_std,
const float w1_inv_std2,
const float w0_inv_std3, const float bias) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x3 = x2 * z.x;
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
float y2 = z.y * z.y;
float y3 = y2 * z.y;
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
auto out = Converter::convert(z);
data[i] = out.x;
data[i + 1] = out.y;
}
}
};
// Vectorized poly_norm_kernel, selected when width > 0 and a _typeConvert
// specialization exists for scalar_t. Each block handles the row selected by
// blockIdx.x, reading and writing it as _f16VecPN<scalar_t, width> vectors.
// The kernel proceeds in three steps:
//   1. each thread accumulates sums of x^2, x^4 and x^6 over its vectors;
//   2. the three sums are block-reduced and thread 0 derives the coefficients
//        w2 * rsqrt(sum2 / hidden_size + epsilon),
//        w1 * rsqrt(sum4 / hidden_size + epsilon),
//        w0 * rsqrt(sum6 / hidden_size + epsilon)
//      plus the bias, publishing them through shared memory;
//   3. each thread rewrites its vectors with poly_norm_inplace.
// NOTE(review): indexing uses hidden_size / width vectors per row, so this
// presumably requires hidden_size to be a multiple of width — confirm at the
// launch site.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
const int vec_hidden_size = hidden_size / width;
// Per-thread partial sums of x^2, x^4 and x^6 respectively.
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
auto [x2, x4, x6] = temp.sum_pows();
variance += x2;
variance2 += x4;
variance3 += x6;
}
// Pack the three partial sums so that one block reduction covers all of them.
float3 thread_variances = make_float3(variance, variance2, variance3);
// Component-wise addition used as the reduction operator.
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
// Only thread 0 consumes the reduced sums; it publishes the coefficients
// through shared memory for the rest of the block.
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
// Make the shared coefficients visible to every thread before they are read.
__syncthreads();
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
out_v[id] = temp;
}
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
Selected when width == 0 or no _typeConvert specialization exists for
scalar_t. Each block handles the row selected by blockIdx.x, one scalar at a
time:
  1. each thread accumulates sums of x^2, x^4 and x^6 over its elements;
  2. the sums are block-reduced and thread 0 derives the coefficients
       w2 * rsqrt(sum2 / hidden_size + epsilon),
       w1 * rsqrt(sum4 / hidden_size + epsilon),
       w0 * rsqrt(sum6 / hidden_size + epsilon)
     plus the bias, publishing them through shared memory;
  3. each output element is
       x * s_w2_inv_std + x^2 * s_w1_inv_std2 + x^3 * s_w0_inv_std3 + s_bias.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Per-thread partial sums of x^2, x^4 and x^6 respectively.
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x4 = x2 * x2;
float x6 = x4 * x2;
variance += x2;
variance2 += x4;
variance3 += x6;
}
// Pack the three partial sums so that one block reduction covers all of them.
float3 thread_variances = make_float3(variance, variance2, variance3);
// Component-wise addition used as the reduction operator.
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
// Only thread 0 consumes the reduced sums; it publishes the coefficients
// through shared memory for the rest of the block.
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
// Make the shared coefficients visible to every thread before they are read.
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x3 = x2 * x;
out[blockIdx.x * hidden_size + idx] =
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
s_bias);
}
}
} // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size]
......@@ -347,21 +186,43 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
if (input.stride(-1) != 1) {
input = input.contiguous();
}
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
int num_dims = input.dim();
int64_t input_stride_d2 = input.stride(-2);
int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
VLLM_DISPATCH_RANK234(num_dims, [&] {
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
const int calculated_vec_size =
std::gcd(16 / sizeof(scalar_t), hidden_size);
const int block_size =
std::min(hidden_size / calculated_vec_size, max_block_size);
dim3 block(block_size);
VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
vllm::rms_norm_kernel<scalar_t, vec_size, tensor_rank>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
input_stride_d2, input_stride_d3, input_stride_d4,
input_shape_d2, input_shape_d3, weight.data_ptr<scalar_t>(),
epsilon, num_tokens, hidden_size);
});
});
});
}
......@@ -379,6 +240,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
......@@ -413,55 +276,11 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
// Helper macro used by poly_norm() to launch vllm::poly_norm_kernel.
//
// It dispatches on input.scalar_type() through VLLM_DISPATCH_FLOATING_TYPES
// and launches poly_norm_kernel<scalar_t, width> with the configuration
// <<<grid, block, 0, stream>>>.
//
// `width` is forwarded verbatim as the kernel's second template argument.
//
// The macro is not self-contained: the expansion site must have the names
// `out`, `input`, `weight`, `bias`, `epsilon`, `hidden_size`, `grid`, `block`
// and `stream` in scope. `weight` and `bias` are accessed with the same
// scalar_t that was selected from `input`.
//
// The expansion already ends in `;`, so a trailing `;` at the call site only
// adds an empty statement.
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
// Host-side entry point that launches vllm::poly_norm_kernel.
//
// Parameters:
//   out     - destination tensor, [..., hidden_size]; must be contiguous and
//             must not share its data pointer with `input`.
//   input   - source tensor, [..., hidden_size]; must be contiguous.
//   weight  - [3]
//   bias    - [1]
//   epsilon - forwarded unchanged to the kernel.
//
// Launch shape: one block per token row (numel / hidden_size rows); the block
// has min(hidden_size, cap) threads, where cap is 1024 for fewer than 256 rows
// and 256 otherwise.
//
// NOTE(review): only `input` and `out` take part in the alignment test below;
// `weight` and `bias` are neither checked for contiguity nor for alignment.
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.data_ptr() != input.data_ptr());

  int hidden_size = input.size(-1);
  const int token_count = input.numel() / hidden_size;

  // The kernel is memory-latency bound in many scenarios. With many token
  // rows, a smaller block lets more blocks be resident at once and hides
  // global-memory latency better, so the thread cap drops from 1024 to 256.
  const int block_cap = (token_count >= 256) ? 256 : 1024;
  dim3 grid(token_count);
  dim3 block(hidden_size < block_cap ? hidden_size : block_cap);

  // For FP16/BF16 tensors the packed, vectorized kernel variant is preferred.
  // A width-8 vector of 16-bit values is the largest useful choice because a
  // single global-memory operation moves at most 128 bits; that in turn
  // requires the data to start on a 16-byte boundary and the row length to be
  // a multiple of 8. OR-ing the two addresses tests both for 16-byte
  // alignment in one step.
  const auto input_addr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  const auto output_addr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  const bool both_16b_aligned = ((input_addr | output_addr) % 16) == 0;
  const bool use_vectorized = both_16b_aligned && (hidden_size % 8 == 0);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (!use_vectorized) {
    // Generic (width 0) variant.
    LAUNCH_FUSED_POLY_NORM(0);
    return;
  }
  LAUNCH_FUSED_POLY_NORM(8);
}
......@@ -7,10 +7,12 @@
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#include "quantization/w8a8/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
......@@ -18,7 +20,7 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type>
template <typename scalar_t, typename fp8_type, int VEC_SIZE>
__global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
......@@ -29,10 +31,21 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
const scalar_t* input_row = input + blockIdx.x * input_stride;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
......@@ -46,11 +59,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * input_stride + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) {
vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[idx];
vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[idx];
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
}
}
......@@ -176,20 +196,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
input_stride, weight.data_ptr<scalar_t>(),
scale.data_ptr<float>(), epsilon, num_tokens,
hidden_size);
const int calculated_vec_size =
std::gcd(16 / sizeof(scalar_t), hidden_size);
const int block_size =
std::min(hidden_size / calculated_vec_size, max_block_size);
dim3 block(block_size);
VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
vec_size>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
input_stride, weight.data_ptr<scalar_t>(),
scale.data_ptr<float>(), epsilon, num_tokens,
hidden_size);
});
});
});
}
......@@ -217,6 +246,8 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
......@@ -242,7 +273,9 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
......
......@@ -24,6 +24,8 @@ struct SSMParamsBase {
int64_t pad_slot_id;
bool delta_softplus;
bool cache_enabled;
int block_size;
index_t A_d_stride;
index_t A_dstate_stride;
......@@ -46,8 +48,9 @@ struct SSMParamsBase {
index_t out_z_batch_stride;
index_t out_z_d_stride;
index_t ssm_states_batch_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dstate_stride;
index_t cache_indices_stride;
// Common data pointers.
void *__restrict__ A_ptr;
......@@ -66,6 +69,9 @@ struct SSMParamsBase {
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
};
......
......@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
......@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;
typename Ktraits::state_t *ssm_states;
if (params.cache_enabled) {
// APC mode: ssm_states points to the base, we'll use absolute cache slots later
ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
dim_id * kNRows * params.ssm_states_dim_stride;
} else {
// Non-APC mode: offset by cache_index as before
ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;
}
float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
......@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// }
constexpr int kChunkSize = kNThreads * kNItems;
const int n_chunks = (seqlen + 2048 - 1) / 2048;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
const int* batch_cache_indices = cache_indices != nullptr ?
cache_indices + batch_id * params.cache_indices_stride : nullptr;
const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
reinterpret_cast<const int*>(params.block_idx_first_scheduled_token_ptr) : nullptr;
const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
......@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 ));
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
......@@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for (int i = 0; i < kNItems; ++i) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
......@@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// Initialize running total
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
scan_t running_prefix;
if (chunk > 0) {
running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
} else {
// Load initial state
if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
r * params.ssm_states_dim_stride +
state_idx * params.ssm_states_dstate_stride;
running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
} else if (has_initial_state) {
// Non-APC mode: load from current batch position
running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
} else {
// No initial state
running_prefix = make_float2(1.0, 0.0);
}
}
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
......@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// There's a syncthreads in the scan op, so we don't need to sync here.
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) {
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
// Store state at the end of each chunk when cache is enabled
if (params.cache_enabled && batch_cache_indices != nullptr) {
size_t cache_slot;
if (chunk == n_chunks - 1) {
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
} else {
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
}
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
r * params.ssm_states_dim_stride +
state_idx * params.ssm_states_dstate_stride;
ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
} else if (!params.cache_enabled && chunk == n_chunks - 1) {
// Non-APC mode: store only final state at current batch position
ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
}
}
......@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
__syncthreads();
......@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#ifndef USE_ROCM
if (params.seqlen <= 128) {
if (params.cache_enabled && params.block_size == 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 128) {
selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
......@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
}
#else
if (params.seqlen <= 256) {
if (params.cache_enabled && params.block_size == 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) {
selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
......@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& D,
const std::optional<at::Tensor>& delta_bias,
const torch::Tensor ssm_states,
bool has_z,
bool has_z,
bool delta_softplus,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t pad_slot_id) {
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
// Reset the parameters
memset(&params, 0, sizeof(params));
......@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
// Set cache parameters - cache is enabled if we have direct cache writing params
params.cache_enabled = block_idx_first_scheduled_token.has_value();
params.block_size = static_cast<int>(block_size);
// Set direct cache writing pointers
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
// All strides are in elements, not bytes.
params.A_d_stride = A.stride(0);
......@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_d_stride = out.stride(0);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
}
else{
if (!is_variable_B) {
......@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_d_stride = out.stride(1);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
}
}
......@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
......@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(cache_indices_, batch_size);
// cache_indices can be either 1D (batch_size,) for non-APC mode
// or 2D (batch_size, max_positions) for APC mode
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
if (is_apc_mode) {
TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
} else {
CHECK_SHAPE(cache_indices_, batch_size);
}
}
......@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices,
has_initial_state,
varlen,
pad_slot_id
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx
);
......
......@@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
// Per-expert outputs filled in parallel
std::vector<torch::Tensor> y_list(E);
y_list.resize(E);
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
for (int64_t e = e_begin; e < e_end; ++e) {
const int64_t te = counts[e];
if (te == 0) {
y_list[e] = at::empty({0, H}, x_c.options());
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
if (apply_router_weight_on_input) {
X_all = X_all.mul(expert_gates.unsqueeze(1));
}
auto Y_all = at::empty({offsets[E], H}, x_c.options());
at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) {
c10::InferenceMode guard;
for (int64_t e = 0; e < E; ++e) {
int64_t start = std::max(offsets[e], idx_begin);
int64_t end = std::min(offsets[e + 1], idx_end);
int64_t te = end - start;
if (te <= 0) {
continue;
}
const int64_t start = offsets[e];
auto sel_tokens =
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto gates_e =
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
if (apply_router_weight_on_input) {
x_e = x_e.mul(gates_e.unsqueeze(1));
}
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto w13_e = w13_packed.select(/*dim=*/0, e);
auto w2_e = w2_packed.select(/*dim=*/0, e);
......@@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
// W2
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
if (!apply_router_weight_on_input) {
y = y.mul(gates_e.unsqueeze(1));
}
// Store per-expert result
y_list[e] = y;
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
}
});
// Concatenate all expert outputs to match expert_tokens order
auto Y_all = at::cat(y_list, /*dim=*/0);
if (!apply_router_weight_on_input) {
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
}
auto out = at::zeros({T, H}, x.options());
out =
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
......
......@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
#endif
}
// Scoring function enums
// Selects the activation applied to a raw router score before the bias is
// added. The kernels in this file receive the selector as a plain `int` and
// compare it against SCORING_SIGMOID; any other value leaves the score
// untouched.
enum ScoringFunc {
SCORING_NONE = 0, // no activation function
SCORING_SIGMOID = 1 // apply sigmoid
};
// Sigmoid evaluated through the hyperbolic tangent, using the identity
//   sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
// (formulation taken from TensorRT-LLM).
__device__ inline float sigmoid_accurate(float x) {
  const float half_x = 0.5f * x;
  return 0.5f * tanhf(half_x) + 0.5f;
}
template <typename T>
__device__ void topk_with_k2(T* output, T const* input,
// Type-generic sigmoid: widens `val` to float with cuda_cast, evaluates
// sigmoid_accurate() in float, and converts the result back to T.
__device__ inline T apply_sigmoid(T val) {
  return cuda_cast<T, float>(sigmoid_accurate(cuda_cast<float, T>(val)));
}
template <typename T>
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group) {
int const num_experts_per_group,
int const scoring_func) {
// Get the top2 per thread
T largest = neg_inf<T>();
T second_largest = neg_inf<T>();
......@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i];
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
value = value + bias[i];
if (value > largest) {
second_largest = largest;
largest = value;
......@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
}
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
largest = input[i];
T value = input[i];
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
value = value + bias[i];
largest = value;
}
}
......@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
}
template <typename T>
__global__ void topk_with_k2_kernel(T* output, T* input,
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
int64_t const num_experts_per_group,
int const scoring_func) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
if (case_id < num_cases) {
input += case_id * num_experts_per_group;
// bias is per expert group, offset to current group
int32_t group_id = case_id % n_group;
T const* group_bias = bias + group_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block();
......@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
topk_with_k2(output, input, group_bias, tile, lane_id,
num_experts_per_group, scoring_func);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
......@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor) {
double routed_scaling_factor, int scoring_func) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
scores_with_bias += case_id * num_experts;
scores += case_id * num_experts;
group_scores += case_id * n_group;
topk_values += case_id * topk;
......@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = (i < num_experts_per_group) &&
is_finite(scores_with_bias[offset + i])
? scores_with_bias[offset + i]
: neg_inf<T>();
T candidates = neg_inf<T>();
if (i < num_experts_per_group) {
// Apply scoring function (if any) and add bias
T input = scores[offset + i];
if (is_finite(input)) {
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
: input;
candidates = score + bias[offset + i];
}
}
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
......@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value =
i < topk
? scores[s_topk_idx[i]]
: cuda_cast<T, float>(0.0f); // Load the valid value of expert
T value = cuda_cast<T, float>(0.0f);
if (i < topk) {
// Load the score value (without bias) for normalization
T input = scores[s_topk_idx[i]];
value =
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
s_topk_value[i] = value;
}
topk_sum +=
......@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
topk_indices[i] = s_topk_idx[i];
topk_values[i] = cuda_cast<T, float>(value);
topk_values[i] = value;
}
} else {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
topk_indices[i] = i;
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
topk_values[i] = 1.0f / topk;
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
......@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
}
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
IdxT* topk_indices, T* scores_with_bias,
int64_t const num_tokens, int64_t const num_experts,
int64_t const n_group, int64_t const topk_group,
int64_t const topk, bool const renormalize,
double const routed_scaling_factor, bool enable_pdl = false,
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk,
bool const renormalize, double const routed_scaling_factor,
int const scoring_func, bool enable_pdl = false,
cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
......@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts / n_group,
scoring_func);
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
......@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
topk_values, topk_indices, bias, num_tokens, n_group,
topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor, scoring_func);
}
// Explicitly instantiates invokeNoAuxTc for score type T and index type IdxT.
// The parameter list has to mirror the invokeNoAuxTc template exactly:
// topk_values is always float*, bias is read-only (T const*), and
// scoring_func is passed between routed_scaling_factor and enable_pdl.
#define INSTANTIATE_NOAUX_TC(T, IdxT)                                       \
  template void invokeNoAuxTc<T, IdxT>(                                     \
      T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
      T const* bias, int64_t const num_tokens, int64_t const num_experts,   \
      int64_t const n_group, int64_t const topk_group, int64_t const topk,  \
      bool const renormalize, double const routed_scaling_factor,           \
      int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
......@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
} // namespace vllm
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
double routed_scaling_factor) {
auto data_type = scores_with_bias.scalar_type();
auto input_size = scores_with_bias.sizes();
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func = 0) {
auto data_type = scores.scalar_type();
auto input_size = scores.sizes();
int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1];
TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
TORCH_CHECK(num_experts % n_group == 0,
"num_experts should be divisible by n_group");
TORCH_CHECK(n_group <= 32,
"n_group should be smaller than or equal to 32 for now");
TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
scoring_func == vllm::moe::SCORING_SIGMOID,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
torch::Tensor group_scores = torch::empty(
{num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
// Always output float32 for topk_values (eliminates Python-side conversion)
torch::Tensor topk_values = torch::empty(
{num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
{num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor topk_indices = torch::empty(
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
switch (data_type) {
case torch::kFloat16:
......@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
vllm::moe::invokeNoAuxTc<half, int32_t>(
reinterpret_cast<half*>(scores.mutable_data_ptr()),
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, false, stream);
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
case torch::kFloat32:
// Handle Float32
......@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, false, stream);
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
case torch::kBFloat16:
// Handle BFloat16
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
num_tokens, num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, false, stream);
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
default:
// Handle other data types
......
kernel_*.cu
\ No newline at end of file
sm*_kernel_*.cu
kernel_selector.h
......@@ -4,128 +4,282 @@ import glob
import itertools
import os
import subprocess
import sys
import jinja2
FILE_HEAD = """
// auto generated by generate.py
ARCHS = []
SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but it’s simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
]
"""
)
TEMPLATE = (
"template __global__ void Marlin<"
"{{a_type_id}}, "
"{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{m_block_size_8}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );"
)
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
QUANT_CONFIGS = [
# AWQ-INT4
{
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels():
    """Delete previously generated kernel sources and the selector header.

    Removes every ``*kernel_*.cu`` file that sits next to this script, then
    ``kernel_selector.h``. The leading wildcard is needed so that files
    written with an ``sm80_``/``sm89_`` prefix are matched as well as plain
    ``kernel_*.cu`` ones. ``rm -f`` is used, so missing files are not an
    error.
    """
    for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])

    filename = os.path.dirname(__file__) + "/kernel_selector.h"
    subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
    """Generate the Marlin instantiation sources and ``kernel_selector.h``.

    Step 1 walks ``QUANT_CONFIGS`` and, for each permitted
    ``(a_type, c_type)`` pair, collects one config dict per surviving
    ``(group_blocks, thread_m_blocks, thread_config)`` combination, keyed by
    ``(a_type, b_type, c_type)``.

    Step 2 writes, for each key, one ``.cu`` file holding the explicit
    ``Marlin<...>`` instantiations, and appends a matching
    ``if``/``else if`` branch per config to the selector header text.

    All files are written next to this script.
    """
    out_dir = os.path.dirname(__file__)
    result_dict = {}

    for quant_config in QUANT_CONFIGS:
        c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
        a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
        b_type = quant_config["b_type"]
        all_group_blocks = quant_config["group_blocks"]
        all_m_blocks = quant_config["thread_m_blocks"]
        all_thread_configs = quant_config["thread_configs"]

        for a_type, c_type in itertools.product(a_types, c_types):
            # FP8-activation kernels are only emitted when a target arch
            # sets SUPPORT_FP8.
            if not SUPPORT_FP8 and a_type == "kFE4M3fn":
                continue
            # For 16-bit activations the output type must equal the
            # activation type (no fp16 <-> bf16 mixing).
            if "16" in a_type and "16" in c_type and a_type != c_type:
                continue

            # The scale type defaults to the output type unless the quant
            # config pins one.
            s_type = quant_config.get("s_type", c_type)
            # The key is registered even if no config survives the filters
            # below, so a file is still written for it.
            config_list = result_dict.setdefault((a_type, b_type, c_type), [])

            for group_blocks, m_blocks, thread_configs in itertools.product(
                all_group_blocks, all_m_blocks, all_thread_configs
            ):
                thread_k, thread_n, threads = thread_configs

                if threads == 256:
                    # for small batch (m_blocks == 1),
                    # we only need (128, 128, 256)
                    # for large batch (m_blocks > 1),
                    # we only need (64, 256, 256)
                    if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
                        continue
                    if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
                        continue

                config_list.append(
                    {
                        "threads": threads,
                        "s_type": s_type,
                        # m_blocks == 0.5 is expressed as one m-block with
                        # m_block_size_8 set.
                        "thread_m_blocks": max(m_blocks, 1),
                        "thread_k_blocks": thread_k // 16,
                        "thread_n_blocks": thread_n // 16,
                        "m_block_size_8": "true" if m_blocks == 0.5 else "false",
                        "stages": "pipe_stages",
                        "group_blocks": group_blocks,
                        "is_zp_float": "false",
                    }
                )

    # Both templates are loop-invariant, so compile them once.
    instantiation_template = jinja2.Template(TEMPLATE)
    selector_template = jinja2.Template(
        "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
        "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
        "{{thread_n_blocks}}, {{thread_k_blocks}}, "
        "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
        "{{is_zp_float}}>;"
    )

    kernel_selector_str = FILE_HEAD_COMMENT

    for (a_type, b_type, c_type), config_list in result_dict.items():
        all_template_str_list = []

        for config in config_list:
            s_type = config["s_type"]
            type_ids = {
                "a_type_id": f"vllm::{a_type}.id()",
                "b_type_id": f"vllm::{b_type}.id()",
                "c_type_id": f"vllm::{c_type}.id()",
                "s_type_id": f"vllm::{s_type}.id()",
            }

            all_template_str_list.append(
                instantiation_template.render(**type_ids, **config)
            )

            conditions = " && ".join(
                [
                    f"a_type == vllm::{a_type}",
                    f"b_type == vllm::{b_type}",
                    f"c_type == vllm::{c_type}",
                    f"s_type == vllm::{s_type}",
                    f"threads == {config['threads']}",
                    f"thread_m_blocks == {config['thread_m_blocks']}",
                    f"thread_n_blocks == {config['thread_n_blocks']}",
                    f"thread_k_blocks == {config['thread_k_blocks']}",
                    f"m_block_size_8 == {config['m_block_size_8']}",
                    f"group_blocks == {config['group_blocks']}",
                    f"is_zp_float == {config['is_zp_float']}",
                ]
            )

            # The very first branch is a plain `if`; all later ones chain
            # with `else if`.
            if kernel_selector_str == FILE_HEAD_COMMENT:
                kernel_selector_str += f"if ({conditions})\n kernel = "
            else:
                kernel_selector_str += f"else if ({conditions})\n kernel = "

            kernel_selector_str += (
                selector_template.render(**type_ids, **config) + "\n"
            )

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"

        # FP8-activation kernels go into sm89_* files, everything else
        # into sm80_* files.
        if a_type == "kFE4M3fn":
            filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
        else:
            filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"

        filename = filename.lower()

        with open(os.path.join(out_dir, filename), "w") as f:
            f.write(file_content)

    # When FP8 kernels were skipped, make the selector fail loudly for FP8
    # activations.
    if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
        kernel_selector_str += (
            "else if (a_type == vllm::kFE4M3fn)\n"
            " TORCH_CHECK(false, "
            '"marlin kernel with fp8 activation is not built.");'
        )

    with open(os.path.join(out_dir, "kernel_selector.h"), "w") as f:
        f.write(kernel_selector_str)
if __name__ == "__main__":
remove_old_kernels()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment