Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
......@@ -3,14 +3,58 @@
// need to be unsigned long long
#include <iostream>
#include "cumem_allocator_compat.h"
#ifndef USE_ROCM
// PyArg_ParseTuple format used when unpacking the 4-tuple exchanged with the
// Python callbacks: four "K" fields, i.e. four unsigned long long values
// (the last one receives the handle pointer as an integer).
static const char* PYARGS_PARSE = "KKKK";
#else
#include <cstdlib>
#include <cerrno>
#include <climits>
// Default chunk size 256MB for ROCm. Can be overridden at runtime by the
// environment variable VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE, specified in megabytes
// (MB). The env value is parsed with strtoull as an integer number of MB
// (decimal or 0x hex). The parsed MB value is converted to bytes. If
// parsing fails, the value is 0, or the multiplication would overflow,
// the default (256MB) is used.
// Fallback chunk size in bytes (256 * 1024 * 1024).
static const unsigned long long DEFAULT_MEMCREATE_CHUNK_SIZE =
    (256ULL * 1024ULL * 1024ULL);

// Returns the chunk size, in bytes, to use for each cuMemCreate call.
//
// Reads VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE (an integer count of megabytes, parsed
// by strtoull with base auto-detection). DEFAULT_MEMCREATE_CHUNK_SIZE is
// returned when the variable is unset, when no digits could be parsed, when
// strtoull reports an error through errno, when the value is zero, or when
// converting megabytes to bytes would not fit in unsigned long long.
static unsigned long long get_memcreate_chunk_size() {
  const unsigned long long kBytesPerMb = 1024ULL * 1024ULL;
  unsigned long long chunk_bytes = DEFAULT_MEMCREATE_CHUNK_SIZE;

  const char* raw = getenv("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE");
  if (raw != nullptr) {
    char* stop = nullptr;
    errno = 0;
    const unsigned long long megabytes = strtoull(raw, &stop, 0);

    // strtoull leaves `stop == raw` when nothing was consumed and sets errno
    // on range errors; either case means the value is unusable.
    const bool parsed = (stop != raw) && (errno == 0);
    // Largest megabyte count whose byte equivalent is still representable.
    const unsigned long long max_megabytes = ULLONG_MAX / kBytesPerMb;

    if (parsed && megabytes != 0 && megabytes <= max_megabytes) {
      chunk_bytes = megabytes * kBytesPerMb;
    }
  }
  return chunk_bytes;
}
// Returns the smaller of two unsigned 64-bit values.
static inline unsigned long long my_min(unsigned long long a,
                                        unsigned long long b) {
  if (b <= a) {
    return b;
  }
  return a;
}
// ROCm variant of the PyArg_ParseTuple format: three "K" (unsigned long long)
// fields followed by one "O" (borrowed PyObject*), which receives the
// per-chunk list instead of a single integer handle pointer.
static const char* PYARGS_PARSE = "KKKO";
#endif
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
char error_msg[10240]; // 10KB buffer to store error messages
// CUresult with numeric value 0; driver-call results in this file are
// compared against it to decide whether a call succeeded.
CUresult no_error = CUresult(0);
......@@ -49,7 +93,12 @@ void ensure_context(unsigned long long device) {
}
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle) {
#else
CUmemGenericAllocationHandle** p_memHandle,
unsigned long long* chunk_sizes, size_t num_chunks) {
#endif
ensure_context(device);
// Define memory allocation properties
CUmemAllocationProp prop = {};
......@@ -58,6 +107,7 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
#ifndef USE_ROCM
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
if (error_code != 0) {
......@@ -67,6 +117,39 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
if (error_code != 0) {
return;
}
#else
for (auto i = 0; i < num_chunks; ++i) {
CUDA_CHECK(cuMemCreate(p_memHandle[i], chunk_sizes[i], &prop, 0));
if (error_code != 0) {
// Clean up previously created handles
for (auto j = 0; j < i; ++j) {
cuMemRelease(*(p_memHandle[j]));
}
return;
}
}
unsigned long long allocated_size = 0;
for (auto i = 0; i < num_chunks; ++i) {
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
CUDA_CHECK(cuMemMap(map_addr, chunk_sizes[i], 0, *(p_memHandle[i]), 0));
if (error_code != 0) {
// unmap previously mapped chunks
unsigned long long unmapped_size = 0;
for (auto j = 0; j < i; ++j) {
void* unmap_addr = (void*)((uintptr_t)d_mem + unmapped_size);
cuMemUnmap(unmap_addr, chunk_sizes[j]);
unmapped_size += chunk_sizes[j];
}
// release all created handles
for (auto j = 0; j < num_chunks; ++j) {
cuMemRelease(*(p_memHandle[j]));
}
return;
}
allocated_size += chunk_sizes[i];
}
#endif
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = device;
......@@ -82,10 +165,16 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
void unmap_and_release(unsigned long long device, ssize_t size,
CUdeviceptr d_mem,
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle) {
#else
CUmemGenericAllocationHandle** p_memHandle,
unsigned long long* chunk_sizes, size_t num_chunks) {
#endif
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
#ifndef USE_ROCM
CUDA_CHECK(cuMemUnmap(d_mem, size));
if (error_code != 0) {
return;
......@@ -94,6 +183,30 @@ void unmap_and_release(unsigned long long device, ssize_t size,
if (error_code != 0) {
return;
}
#else
unsigned long long allocated_size = 0;
CUresult first_error = no_error;
for (auto i = 0; i < num_chunks; ++i) {
void* map_addr = (void*)((uintptr_t)d_mem + allocated_size);
CUresult status = cuMemUnmap(map_addr, chunk_sizes[i]);
if (status != no_error && first_error == no_error) {
first_error = status;
}
allocated_size += chunk_sizes[i];
}
for (auto i = 0; i < num_chunks; ++i) {
CUresult status = cuMemRelease(*(p_memHandle[i]));
if (status != no_error && first_error == no_error) {
first_error = status;
}
}
if (first_error != no_error) {
CUDA_CHECK(first_error);
}
#endif
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
......@@ -120,6 +233,36 @@ PyObject* create_tuple_from_c_integers(unsigned long long a,
return tuple; // Return the created tuple
}
// Builds the Python 4-tuple (a, b, c, chunks), where `chunks` is a list with
// one (handle_pointer_as_int, chunk_size) tuple per entry of vec/chunk_sizes.
//
// a, b, c      Values stored as Python ints in slots 0..2.
// vec          Array of `num_chunks` handle pointers; each pointer value (not
//              the handle it points to) is stored as an integer.
// chunk_sizes  Array of `num_chunks` sizes, stored alongside each pointer.
// num_chunks   Number of entries in both arrays.
//
// Returns a new reference, or NULL with a Python exception set if any
// allocation fails. On failure every object created so far is released, so
// nothing leaks and no partially initialised tuple escapes to the caller.
PyObject* create_tuple_from_c_mixed(unsigned long long a, unsigned long long b,
                                    unsigned long long c,
                                    CUmemGenericAllocationHandle** vec,
                                    unsigned long long* chunk_sizes,
                                    size_t num_chunks) {
  PyObject* tuple = PyTuple_New(4);
  if (!tuple) {
    return NULL;
  }

  PyObject* list = PyList_New((Py_ssize_t)num_chunks);
  if (!list) {
    Py_DECREF(tuple);
    return NULL;
  }

  for (size_t i = 0; i < num_chunks; ++i) {
    PyObject* addr = PyLong_FromUnsignedLongLong((unsigned long long)(vec[i]));
    PyObject* size =
        PyLong_FromUnsignedLongLong((unsigned long long)(chunk_sizes[i]));
    PyObject* addr_size_pair = PyTuple_New(2);
    if (!addr || !size || !addr_size_pair) {
      Py_XDECREF(addr);
      Py_XDECREF(size);
      Py_XDECREF(addr_size_pair);
      // Deallocating a partially filled list/tuple is safe: unset slots are
      // NULL and are skipped by the container's destructor.
      Py_DECREF(list);
      Py_DECREF(tuple);
      return NULL;
    }
    // The SetItem calls steal the references to addr, size and the pair.
    PyTuple_SetItem(addr_size_pair, 0, addr);
    PyTuple_SetItem(addr_size_pair, 1, size);
    PyList_SetItem(list, (Py_ssize_t)i, addr_size_pair);
  }

  PyObject* py_a = PyLong_FromUnsignedLongLong(a);
  PyObject* py_b = PyLong_FromUnsignedLongLong(b);
  PyObject* py_c = PyLong_FromUnsignedLongLong(c);
  if (!py_a || !py_b || !py_c) {
    Py_XDECREF(py_a);
    Py_XDECREF(py_b);
    Py_XDECREF(py_c);
    Py_DECREF(list);
    Py_DECREF(tuple);
    return NULL;
  }

  PyTuple_SetItem(tuple, 0, py_a);
  PyTuple_SetItem(tuple, 1, py_b);
  PyTuple_SetItem(tuple, 2, py_c);
  PyTuple_SetItem(tuple, 3, list);
  return tuple;
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
......@@ -147,14 +290,55 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
CUdeviceptr d_mem;
#ifndef USE_ROCM
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
if (error_code != 0) {
return nullptr;
}
#else
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, granularity, 0, 0));
if (error_code != 0) {
return nullptr;
}
#endif
#ifndef USE_ROCM
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
#else
// Make sure chunk size is aligned with hardware granularity. The base
// chunk size can be configured via environment variable
// ``VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE``; otherwise
// DEFAULT_MEMCREATE_CHUNK_SIZE is used.
size_t base_chunk = (size_t)get_memcreate_chunk_size();
size_t aligned_chunk_size =
((base_chunk + granularity - 1) / granularity) * granularity;
size_t num_chunks =
(alignedSize + aligned_chunk_size - 1) / aligned_chunk_size;
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
for (auto i = 0; i < num_chunks; ++i) {
p_memHandle[i] = (CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
if (p_memHandle[i] == nullptr) {
std::cerr << "ERROR: malloc failed for p_memHandle[" << i << "].\n";
for (auto j = 0; j < i; ++j) {
free(p_memHandle[j]);
}
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
chunk_sizes[i] = (unsigned long long)my_min(
(unsigned long long)(alignedSize - i * aligned_chunk_size),
(unsigned long long)aligned_chunk_size);
}
#endif
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
......@@ -164,9 +348,15 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
#ifndef USE_ROCM
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
#else
PyObject* arg_tuple = create_tuple_from_c_mixed(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, p_memHandle, chunk_sizes, num_chunks);
#endif
// Call g_python_malloc_callback
PyObject* py_result =
......@@ -182,7 +372,27 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
PyGILState_Release(gstate);
// do the final mapping
#ifndef USE_ROCM
create_and_map(device, alignedSize, d_mem, p_memHandle);
#else
create_and_map(device, alignedSize, d_mem, p_memHandle, chunk_sizes,
num_chunks);
free(chunk_sizes);
#endif
if (error_code != 0) {
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, alignedSize));
#ifndef USE_ROCM
free(p_memHandle);
#else
for (size_t i = 0; i < num_chunks; ++i) {
free(p_memHandle[i]);
}
free(p_memHandle);
#endif
return nullptr;
}
return (void*)d_mem;
}
......@@ -206,36 +416,96 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
Py_XDECREF(py_result);
Py_XDECREF(py_ptr);
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
if (!PyArg_ParseTuple(py_result, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
Py_XDECREF(py_result);
Py_XDECREF(py_ptr);
return;
}
PyGILState_Release(gstate);
// For ROCm, copy the Python list of (addr,size) pairs into C arrays while
// holding the GIL. Then release the GIL and call the unmap/release helper
// using the copied arrays. This avoids calling PyList_* APIs without the
// GIL (which is undefined behavior and can crash when called from other
// threads).
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
#ifdef USE_ROCM
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
std::cerr << "ERROR: malloc failed for p_memHandle in my_free."
<< std::endl;
return;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
std::cerr << "ERROR: malloc failed for chunk_sizes in my_free."
<< std::endl;
return;
}
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
}
// recv_size == size
// recv_device == device
// Drop temporary Python refs, then release the GIL before calling into
// non-Python APIs.
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
// Free memory
unmap_and_release(device, size, d_mem, p_memHandle, chunk_sizes, num_chunks);
#else
// Non-ROCm path: simple integer handle already extracted; drop temporary
// Python refs while still holding the GIL, then release it.
Py_DECREF(py_ptr);
Py_DECREF(py_result);
PyGILState_Release(gstate);
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
#endif
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, size));
if (error_code != 0) {
return;
#ifndef USE_ROCM
free(p_memHandle);
#else
for (auto i = 0; i < num_chunks; ++i) {
free(p_memHandle[i]);
}
free(p_memHandle);
free(chunk_sizes);
#endif
}
// ---------------------------------------------------------------------------
......@@ -271,19 +541,87 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
#else
if (!PyList_Check(recv_p_memHandle)) {
PyErr_SetString(PyExc_TypeError,
"Expected a list for the 4th argument on ROCm");
return nullptr;
}
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
if (num_chunks < 0) {
return nullptr; // PyList_Size sets an exception on error.
}
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
return nullptr;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
return nullptr;
}
for (Py_ssize_t i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
if (item == nullptr || !PyTuple_Check(item) || PyTuple_Size(item) != 2) {
free(p_memHandle);
free(chunk_sizes);
PyErr_SetString(
PyExc_TypeError,
"List items must be tuples of size 2 (handle_addr, size)");
return nullptr;
}
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
if (addr_py == nullptr || size_py == nullptr) {
free(p_memHandle);
free(chunk_sizes);
return nullptr; // PyTuple_GetItem sets an exception
}
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
if (PyErr_Occurred()) {
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
chunk_sizes[i] = (unsigned long long)PyLong_AsUnsignedLongLong(size_py);
if (PyErr_Occurred()) {
free(p_memHandle);
free(chunk_sizes);
return nullptr;
}
}
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
num_chunks);
free(p_memHandle);
free(chunk_sizes);
#endif
if (error_code != 0) {
error_code = no_error;
......@@ -301,19 +639,56 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
unsigned long long recv_d_mem;
#ifndef USE_ROCM
unsigned long long recv_p_memHandle;
#else
PyObject* recv_p_memHandle;
#endif
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
if (!PyArg_ParseTuple(args, PYARGS_PARSE, &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
#ifndef USE_ROCM
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
#else
Py_ssize_t num_chunks = PyList_Size(recv_p_memHandle);
CUmemGenericAllocationHandle** p_memHandle =
(CUmemGenericAllocationHandle**)malloc(
num_chunks * sizeof(CUmemGenericAllocationHandle*));
if (p_memHandle == nullptr) {
PyErr_SetString(PyExc_MemoryError, "malloc failed for p_memHandle");
return nullptr;
}
unsigned long long* chunk_sizes =
(unsigned long long*)malloc(num_chunks * sizeof(unsigned long long));
if (chunk_sizes == nullptr) {
free(p_memHandle);
PyErr_SetString(PyExc_MemoryError, "malloc failed for chunk_sizes");
return nullptr;
}
for (auto i = 0; i < num_chunks; ++i) {
PyObject* item = PyList_GetItem(recv_p_memHandle, i);
PyObject* addr_py = PyTuple_GetItem(item, 0);
PyObject* size_py = PyTuple_GetItem(item, 1);
p_memHandle[i] =
(CUmemGenericAllocationHandle*)PyLong_AsUnsignedLongLong(addr_py);
chunk_sizes[i] = PyLong_AsUnsignedLongLong(size_py);
}
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle, chunk_sizes,
num_chunks);
free(p_memHandle);
free(chunk_sizes);
#endif
if (error_code != 0) {
error_code = no_error;
......
#pragma once
#ifdef USE_ROCM
////////////////////////////////////////
// For compatibility with CUDA and ROCm
////////////////////////////////////////
#include <hip/hip_runtime_api.h>
extern "C" {
#ifndef CUDA_SUCCESS
#define CUDA_SUCCESS hipSuccess
#endif // CUDA_SUCCESS
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
typedef unsigned long long CUdevice;
typedef hipDeviceptr_t CUdeviceptr;
typedef hipError_t CUresult;
typedef hipCtx_t CUcontext;
typedef hipStream_t CUstream;
typedef hipMemGenericAllocationHandle_t CUmemGenericAllocationHandle;
typedef hipMemAllocationGranularity_flags CUmemAllocationGranularity_flags;
typedef hipMemAllocationProp CUmemAllocationProp;
typedef hipMemAccessDesc CUmemAccessDesc;
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_MEM_ALLOC_GRANULARITY_MINIMUM hipMemAllocationGranularityMinimum
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
#define CU_MEM_ALLOCATION_COMP_NONE 0x0
// Error Handling
// https://docs.nvidia.com/cuda/archive/11.4.4/cuda-driver-api/group__CUDA__ERROR.html
// CUDA-named shim: stores the string returned by hipGetErrorString(hipError)
// in *pStr. Always returns CUDA_SUCCESS; pStr is dereferenced unconditionally,
// so it must not be null.
CUresult cuGetErrorString(CUresult hipError, const char** pStr) {
*pStr = hipGetErrorString(hipError);
return CUDA_SUCCESS;
}
// Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
// CUDA-named shim: forwards `ctx` to hipCtxGetCurrent and returns its result
// unchanged.
CUresult cuCtxGetCurrent(CUcontext* ctx) {
// This API is deprecated on the AMD platform, only for equivalent cuCtx
// driver API on the NVIDIA platform.
return hipCtxGetCurrent(ctx);
}
// CUDA-named shim: forwards `ctx` to hipCtxSetCurrent and returns its result
// unchanged.
CUresult cuCtxSetCurrent(CUcontext ctx) {
// This API is deprecated on the AMD platform, only for equivalent cuCtx
// driver API on the NVIDIA platform.
return hipCtxSetCurrent(ctx);
}
// Primary Context Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PRIMARY__CTX.html
// CUDA-named shim: forwards (ctx, dev) to hipDevicePrimaryCtxRetain and
// returns its result unchanged. `dev` is the unsigned long long CUdevice
// typedef declared in this header.
CUresult cuDevicePrimaryCtxRetain(CUcontext* ctx, CUdevice dev) {
return hipDevicePrimaryCtxRetain(ctx, dev);
}
// Virtual Memory Management
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html
// CUDA-named shim: forwards (ptr, size) to hipMemAddressFree and returns its
// result unchanged.
CUresult cuMemAddressFree(CUdeviceptr ptr, size_t size) {
return hipMemAddressFree(ptr, size);
}
// CUDA-named shim: forwards all arguments, in the same order, to
// hipMemAddressReserve and returns its result unchanged.
CUresult cuMemAddressReserve(CUdeviceptr* ptr, size_t size, size_t alignment,
CUdeviceptr addr, unsigned long long flags) {
return hipMemAddressReserve(ptr, size, alignment, addr, flags);
}
// CUDA-named shim: forwards (handle, size, prop, flags) to hipMemCreate and
// returns its result unchanged.
CUresult cuMemCreate(CUmemGenericAllocationHandle* handle, size_t size,
const CUmemAllocationProp* prop,
unsigned long long flags) {
return hipMemCreate(handle, size, prop, flags);
}
// CUDA-named shim: forwards (granularity, prop, option) to
// hipMemGetAllocationGranularity and returns its result unchanged.
CUresult cuMemGetAllocationGranularity(
size_t* granularity, const CUmemAllocationProp* prop,
CUmemAllocationGranularity_flags option) {
return hipMemGetAllocationGranularity(granularity, prop, option);
}
// CUDA-named shim: forwards (dptr, size, offset, handle, flags) to hipMemMap
// and returns its result unchanged. Note the handle is taken by value.
CUresult cuMemMap(CUdeviceptr dptr, size_t size, size_t offset,
CUmemGenericAllocationHandle handle,
unsigned long long flags) {
return hipMemMap(dptr, size, offset, handle, flags);
}
// CUDA-named shim: forwards `handle` (by value) to hipMemRelease and returns
// its result unchanged.
CUresult cuMemRelease(CUmemGenericAllocationHandle handle) {
return hipMemRelease(handle);
}
// CUDA-named shim: forwards (ptr, size, desc, count) to hipMemSetAccess and
// returns its result unchanged.
CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
const CUmemAccessDesc* desc, size_t count) {
return hipMemSetAccess(ptr, size, desc, count);
}
// CUDA-named shim: forwards (ptr, size) to hipMemUnmap and returns its result
// unchanged.
CUresult cuMemUnmap(CUdeviceptr ptr, size_t size) {
return hipMemUnmap(ptr, size);
}
} // extern "C"
#else
////////////////////////////////////////
// Import CUDA headers for NVIDIA GPUs
////////////////////////////////////////
#include <cuda_runtime_api.h>
#include <cuda.h>
#endif
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from typing import Union
from cutlass_library import *
......@@ -22,31 +21,31 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative = enum_auto()
VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeNames: dict[VLLMDataType | DataType, str] = {
**DataTypeNames, # type: ignore
**{
VLLMDataType.u4b8: "u4b8",
VLLMDataType.u8b128: "u8b128",
}
},
}
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeTag: dict[VLLMDataType | DataType, str] = {
**DataTypeTag, # type: ignore
**{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}
},
}
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
VLLMDataTypeSize: dict[VLLMDataType | DataType, int] = {
**DataTypeSize, # type: ignore
**{
VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8,
}
},
}
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeVLLMScalarTypeTag: dict[VLLMDataType | DataType, str] = {
VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4",
......@@ -57,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.bf16: "vllm::kBfloat16",
}
VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataTypeTorchDataTypeTag: dict[VLLMDataType | DataType, str] = {
DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
......@@ -67,15 +66,11 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.f32: "at::ScalarType::Float",
}
VLLMKernelScheduleTag: dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecialized:
"cutlass::gemm::KernelTmaWarpSpecialized",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
}
}
VLLMKernelScheduleTag: dict[MixedInputKernelScheduleType | KernelScheduleType, str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", # noqa: E501
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", # noqa: E501
},
}
......@@ -88,3 +88,32 @@
// Expands to AT_DISPATCH_SWITCH(TYPE, NAME, ...) with the case list produced
// by VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES applied to the remaining
// arguments (the callable to run for the matched type).
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
// Internal helper for VLLM_DISPATCH_VEC_SIZE: one switch arm body that binds
// the compile-time constant `vec_size` to N, invokes the callable, and leaves
// the switch.
#define VLLM_DISPATCH_VEC_SIZE_ARM(N, ...) \
  {                                        \
    constexpr int vec_size = N;            \
    __VA_ARGS__();                         \
    break;                                 \
  }

// Turns the runtime value VEC_SIZE into a `constexpr int vec_size` visible to
// the callable passed as the remaining argument(s), then calls it with no
// arguments. Recognised values are 16, 8, 4 and 2; every other value runs the
// callable with vec_size == 1.
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...)       \
  switch (VEC_SIZE) {                               \
    case 16:                                        \
      VLLM_DISPATCH_VEC_SIZE_ARM(16, __VA_ARGS__)   \
    case 8:                                         \
      VLLM_DISPATCH_VEC_SIZE_ARM(8, __VA_ARGS__)    \
    case 4:                                         \
      VLLM_DISPATCH_VEC_SIZE_ARM(4, __VA_ARGS__)    \
    case 2:                                         \
      VLLM_DISPATCH_VEC_SIZE_ARM(2, __VA_ARGS__)    \
    default:                                        \
      VLLM_DISPATCH_VEC_SIZE_ARM(1, __VA_ARGS__)    \
  }
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cuda_runtime.h>
#include <type_traits>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "type_convert.cuh"
// Fails via TORCH_CHECK unless x.scalar_type() equals st; the message names
// the argument (stringified), its actual dtype and the expected one.
#define CHECK_TYPE(x, st) \
TORCH_CHECK(x.scalar_type() == st, #x " dtype is ", x.scalar_type(), \
", while ", st, " is expected")
// Fails via TORCH_CHECK unless x.is_cuda() is true.
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
// Fails via TORCH_CHECK unless x.is_contiguous() is true.
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Runs CHECK_TH_CUDA(x) followed by CHECK_CONTIGUOUS(x).
// Wrapped in do { ... } while (0) so the macro expands to exactly one
// statement: `if (cond) CHECK_INPUT(t);` now guards both checks rather than
// only the first. Invoke it with a trailing semicolon.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_TH_CUDA(x);    \
    CHECK_CONTIGUOUS(x); \
  } while (0)
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
// Argument-less __syncwarp() substitute: a wave barrier bracketed by a
// release fence before it and an acquire fence after it, both issued at
// "wavefront" scope. The statement order is significant; do not reorder.
__device__ inline void __syncwarp() {
// Release fence at wavefront scope, issued before the barrier.
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
// Execution barrier for the wave.
__builtin_amdgcn_wave_barrier();
// Acquire fence at wavefront scope, issued after the barrier.
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
#endif
#else
#define FINAL_MASK 0xffffffff
#endif
namespace tensorrt_llm::common {

// packed_as<T, num>::type names a vector type holding `num` values of T.
// Only the primary template is declared here; the uint specializations
// needed by the kernel in this file follow.
template <typename T, int num>
struct packed_as;

// One 32-bit word.
template <>
struct packed_as<uint, 1> {
  typedef uint type;
};

// Two 32-bit words.
template <>
struct packed_as<uint, 2> {
  typedef uint2 type;
};

// Four 32-bit words.
template <>
struct packed_as<uint, 4> {
  typedef uint4 type;
};

// Sums `val` over a 32-lane group using XOR shuffles with lane offsets
// 16, 8, 4, 2, 1; every participating lane returns the total.
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int lane_xor = 16; lane_xor > 0; lane_xor >>= 1) {
    val += __shfl_xor_sync(FINAL_MASK, val, lane_xor, 32);
  }
  return val;
}

// Evaluates (m + n - 1) / n, i.e. m / n rounded up for positive integers.
template <typename T>
inline __device__ __host__ T divUp(T m, T n) {
  return (m + n - 1) / n;
}

}  // namespace tensorrt_llm::common
namespace tensorrt_llm::kernels {
// NOTE(zhuhaoran): This kernel is adapted from TensorRT-LLM implementation,
// with added support for passing the cos_sin_cache as an input.
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
// Perform per-head QK Norm and RoPE in a single kernel.
//
// Work decomposition (as implemented below):
//  * Every group of 32 consecutive threads handles one (token, head) pair,
//    where the head is a Q head or a K head; V heads are never read or
//    written. Each lane owns head_dim / 32 contiguous elements.
//  * The group index is blockIdx.x * (blockDim.x / 32) + threadIdx.x / 32, so
//    blockDim.x is expected to be a multiple of 32 and the grid must supply
//    at least num_tokens * (num_heads_q + num_heads_k) groups. Groups whose
//    token index is >= num_tokens return immediately.
//  * qkv is updated in place. It is indexed as
//    [token][num_heads_q + num_heads_k + num_heads_v][head_dim].
//  * cos_sin_cache is indexed as [position][head_dim]: the first head_dim / 2
//    entries of a row are read as cos values, the second half as sin values.
//  * On non-ROCm builds, when not compiling for __CUDA_ARCH__ >= 800, the
//    body is compiled out for c10::BFloat16 and the kernel returns at once.
//
// scalar_t_in: data type of QKV and RMSNorm weights
// scalar_t_cache: data type of cos/sin cache
// head_dim: the dimension of each head
// interleave: interleave=!is_neox.
template <typename scalar_t_in, typename scalar_t_cache, int head_dim,
          bool interleave>
__global__ void fusedQKNormRopeKernel(
    void* qkv_void,                  // Combined QKV tensor
    int const num_heads_q,           // Number of query heads
    int const num_heads_k,           // Number of key heads
    int const num_heads_v,           // Number of value heads
    float const eps,                 // Epsilon for RMS normalization
    void const* q_weight_void,       // RMSNorm weights for query
    void const* k_weight_void,       // RMSNorm weights for key
    void const* cos_sin_cache_void,  // Pre-computed cos/sin cache
    int64_t const* position_ids,     // Position IDs for RoPE
    int const num_tokens             // Number of tokens
) {
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  if constexpr ((std::is_same_v<scalar_t_in, c10::BFloat16>) ||
                std::is_same_v<scalar_t_cache, c10::BFloat16>) {
    return;
  } else {
#endif
  using Converter = vllm::_typeConvert<scalar_t_in>;
  static_assert(Converter::exists,
                "Input QKV data type is not supported for this CUDA "
                "architecture or toolkit version.");
  using T_in = typename Converter::hip_type;
  using T2_in = typename Converter::packed_hip_type;
  using CacheConverter = vllm::_typeConvert<scalar_t_cache>;
  static_assert(CacheConverter::exists,
                "Cache data type is not supported for this CUDA architecture "
                "or toolkit version.");
  using T_cache = typename CacheConverter::hip_type;

  T_in* qkv = reinterpret_cast<T_in*>(qkv_void);
  T_in const* q_weight = reinterpret_cast<T_in const*>(q_weight_void);
  T_in const* k_weight = reinterpret_cast<T_in const*>(k_weight_void);
  T_cache const* cos_sin_cache =
      reinterpret_cast<T_cache const*>(cos_sin_cache_void);

  int const warpsPerBlock = blockDim.x / 32;
  int const warpId = threadIdx.x / 32;
  int const laneId = threadIdx.x % 32;

  // Calculate global warp index to determine which head/token this warp
  // processes
  int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;

  // Total number of attention heads (Q and K)
  int const total_qk_heads = num_heads_q + num_heads_k;

  // Determine which token and head type (Q or K) this warp processes
  int const tokenIdx = globalWarpIdx / total_qk_heads;
  int const localHeadIdx = globalWarpIdx % total_qk_heads;

  // Skip if this warp is assigned beyond the number of tokens. The condition
  // depends only on the 32-thread group index, so all lanes of a group take
  // the same path ahead of the shuffles below.
  if (tokenIdx >= num_tokens) return;

  bool const isQ = localHeadIdx < num_heads_q;
  int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q;

  int const num_heads = num_heads_q + num_heads_k + num_heads_v;

  static_assert(head_dim % (32 * 2) == 0,
                "head_dim must be divisible by 64 (each warp processes one "
                "head, and each thread gets even number of "
                "elements)");
  constexpr int numElemsPerThread = head_dim / 32;
  float elements[numElemsPerThread];
  // NOTE(review): the per-thread byte count is derived from
  // sizeof(__nv_bfloat16) rather than sizeof(T_in); this presumes a 2-byte
  // input element type - confirm before instantiating with wider types.
  constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16);
  static_assert(elemSizeBytes % 4 == 0,
                "elemSizeBytes must be a multiple of 4");
  constexpr int vecSize =
      elemSizeBytes /
      4;  // Use packed_as<uint, vecSize> to perform loading/saving.
  using vec_T = typename tensorrt_llm::common::packed_as<uint, vecSize>::type;

  // Element offsets are computed in 64 bits: tokenIdx * num_heads * head_dim
  // exceeds the range of a 32-bit int once the QKV tensor holds 2^31 or more
  // elements, which would otherwise wrap silently.
  int64_t const tokenOffset =
      static_cast<int64_t>(tokenIdx) * num_heads * head_dim;
  int64_t offsetWarp;  // Offset for the warp
  if (isQ) {
    // Q segment: token offset + head offset within Q segment
    offsetWarp = tokenOffset + static_cast<int64_t>(headIdx) * head_dim;
  } else {
    // K segment: token offset + entire Q segment + head offset within K
    // segment
    offsetWarp = tokenOffset + static_cast<int64_t>(num_heads_q) * head_dim +
                 static_cast<int64_t>(headIdx) * head_dim;
  }
  int64_t const offsetThread = offsetWarp + laneId * numElemsPerThread;

  // Sum of squares for RMSNorm
  float sumOfSquares = 0.0f;

  // Load.
  {
    vec_T vec = *reinterpret_cast<vec_T const*>(&qkv[offsetThread]);
    constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
    for (int i = 0; i < num_packed_elems; i++) {
      // Interpret the generic vector chunk as the specific packed type
      T2_in packed_val = *(reinterpret_cast<T2_in*>(&vec) + i);
      // Convert to float2 for computation
      float2 vals = Converter::convert(packed_val);
      sumOfSquares += vals.x * vals.x;
      sumOfSquares += vals.y * vals.y;

      elements[2 * i] = vals.x;
      elements[2 * i + 1] = vals.y;
    }
  }

  // Reduce sum across warp using the utility function
  sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

  // Compute RMS normalization factor
  float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

  // Normalize elements
#pragma unroll
  for (int i = 0; i < numElemsPerThread; i++) {
    int dim = laneId * numElemsPerThread + i;
    float weight = isQ ? Converter::convert(q_weight[dim])
                       : Converter::convert(k_weight[dim]);
    elements[i] *= rms_rcp * weight;
  }

  // Apply RoPE to normalized elements
  float elements2[numElemsPerThread];  // Additional buffer required for RoPE.

  int64_t pos_id = position_ids[tokenIdx];

  // Calculate cache pointer for this position - similar to
  // pos_encoding_kernels.cu
  T_cache const* cache_ptr = cos_sin_cache + pos_id * head_dim;
  int const embed_dim = head_dim / 2;
  T_cache const* cos_ptr = cache_ptr;
  T_cache const* sin_ptr = cache_ptr + embed_dim;

  if constexpr (interleave) {
    // Perform interleaving. Use pre-computed cos/sin values.
    // Adjacent element pairs (2i, 2i+1) held by the same lane are rotated
    // together, so no cross-lane exchange is needed in this branch.
#pragma unroll
    for (int i = 0; i < numElemsPerThread / 2; ++i) {
      int const idx0 = 2 * i;
      int const idx1 = 2 * i + 1;

      float const val0 = elements[idx0];
      float const val1 = elements[idx1];

      int const dim_idx = laneId * numElemsPerThread + idx0;
      int const half_dim = dim_idx / 2;
      float const cos_val =
          CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
      float const sin_val =
          CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

      elements[idx0] = val0 * cos_val - val1 * sin_val;
      elements[idx1] = val0 * sin_val + val1 * cos_val;
    }
  } else {
    // Before data exchange with in warp, we need to sync.
    __syncwarp();
    // Get the data from the other half of the warp. Use pre-computed cos/sin
    // values.
#pragma unroll
    for (int i = 0; i < numElemsPerThread; i++) {
      // Partner lane is laneId ^ 16, i.e. the lane holding the element
      // head_dim / 2 away from this one.
      elements2[i] = __shfl_xor_sync(FINAL_MASK, elements[i], 16);
      if (laneId < 16) {
        elements2[i] = -elements2[i];
      }

      int dim_idx = laneId * numElemsPerThread + i;
      dim_idx = (dim_idx * 2) % head_dim;
      int half_dim = dim_idx / 2;
      // Use pre-computed cos/sin from cache
      float cos_val = CacheConverter::convert(VLLM_LDG(cos_ptr + half_dim));
      float sin_val = CacheConverter::convert(VLLM_LDG(sin_ptr + half_dim));

      elements[i] = elements[i] * cos_val + elements2[i] * sin_val;
    }
    // __shfl_xor_sync does not provide memfence. Need to sync again.
    __syncwarp();
  }

  // Store.
  {
    vec_T vec;
    constexpr int num_packed_elems = elemSizeBytes / sizeof(T2_in);
#pragma unroll
    for (int i = 0; i < num_packed_elems; i++) {
      // Convert from float2 back to the specific packed type
      T2_in packed_val = Converter::convert(
          make_float2(elements[2 * i], elements[2 * i + 1]));
      // Place it into the generic vector
      *(reinterpret_cast<T2_in*>(&vec) + i) = packed_val;
    }
    *reinterpret_cast<vec_T*>(&qkv[offsetThread]) = vec;
  }

#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
  }
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
//
// Turns the runtime boolean `interleave` into a `const bool` named by the
// second argument (`INTERLEAVE`) and pastes the trailing statement(s)
// (`__VA_ARGS__`) once in each branch: once with the constant initialised to
// `true`, once with it initialised to `false`. Because the initialiser is a
// literal, the pasted code may use the name as a template argument.
//
// NOTE: the expansion is a bare `if { ... } else { ... }` statement (it is not
// wrapped in `do { } while (0)`), so it should not be used as the unbraced
// body of an enclosing `if`/`else`.
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
// Host-side launcher for fusedQKNormRopeKernel.
//
// Work decomposition (as set up below): one 32-thread group ("warp") per
// (token, Q-or-K head) pair, 256 threads (8 such groups) per block, and a 1-D
// grid sized to cover num_tokens * (num_heads_q + num_heads_k) groups.
//
// Parameters:
//   qkv            device pointer to the packed QKV buffer, modified in place
//   num_tokens     number of rows (tokens) in qkv
//   num_heads_q/k/v  head counts of the Q, K and V segments
//   head_dim       per-head dimension; only 64, 128 and 256 are instantiated
//   eps            epsilon forwarded to the kernel's RMS normalisation
//   q_weight, k_weight, cos_sin_cache, position_ids
//                  device pointers forwarded unchanged to the kernel
//   interleave     runtime flag selecting the kernel's `interleave` template
//                  parameter
//   stream         CUDA stream the kernel is enqueued on
//
// Raises (via TORCH_CHECK) if head_dim is not one of the supported values.
// The launch is asynchronous; no launch-error check is performed here.
template <typename scalar_t_in, typename scalar_t_cache>
void launchFusedQKNormRope(void* qkv, int const num_tokens,
                           int const num_heads_q, int const num_heads_k,
                           int const num_heads_v, int const head_dim,
                           float const eps, void const* q_weight,
                           void const* k_weight, void const* cos_sin_cache,
                           bool const interleave, int64_t const* position_ids,
                           cudaStream_t stream) {
  // Validate the head dimension up front so that an unsupported value is
  // reported even when there is no work to launch.
  TORCH_CHECK(head_dim == 64 || head_dim == 128 || head_dim == 256,
              "Unsupported head dimension for fusedQKNormRope: ", head_dim);

  constexpr int blockSize = 256;
  constexpr int warpsPerBlock = blockSize / 32;

  int const totalQKHeads = num_heads_q + num_heads_k;
  int const totalWarps = num_tokens * totalQKHeads;

  // An empty batch (or no Q/K heads) would yield a grid of 0 blocks, which is
  // an invalid launch configuration and would leave a pending CUDA error for
  // some later, unrelated call to trip over. There is nothing to do, so
  // return before launching.
  if (totalWarps <= 0) {
    return;
  }

  int const gridSize = common::divUp(totalWarps, warpsPerBlock);
  dim3 gridDim(gridSize);
  dim3 blockDim(blockSize);

  // head_dim and interleave are template parameters of the kernel, so each
  // supported combination needs its own instantiation.
  switch (head_dim) {
    case 64:
      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
        fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 64, INTERLEAVE>
            <<<gridDim, blockDim, 0, stream>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
                k_weight, cos_sin_cache, position_ids, num_tokens);
      });
      break;
    case 128:
      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
        fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 128, INTERLEAVE>
            <<<gridDim, blockDim, 0, stream>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
                k_weight, cos_sin_cache, position_ids, num_tokens);
      });
      break;
    case 256:
      DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
        fusedQKNormRopeKernel<scalar_t_in, scalar_t_cache, 256, INTERLEAVE>
            <<<gridDim, blockDim, 0, stream>>>(
                qkv, num_heads_q, num_heads_k, num_heads_v, eps, q_weight,
                k_weight, cos_sin_cache, position_ids, num_tokens);
      });
      break;
    default:
      // Unreachable after the check above; kept as a safety net in case the
      // list of cases and the up-front validation ever diverge.
      TORCH_CHECK(false,
                  "Unsupported head dimension for fusedQKNormRope: ", head_dim);
  }
}
} // namespace tensorrt_llm::kernels
// Torch-facing entry point: validates the tensors, then dispatches on the
// dtypes of `qkv` and `cos_sin_cache` and forwards everything to
// tensorrt_llm::kernels::launchFusedQKNormRope. `qkv` is passed as a mutable
// pointer; the other tensors are passed as read-only pointers.
//
// The interleave flag handed to the launcher is the negation of `is_neox`.
void fused_qk_norm_rope(
    torch::Tensor& qkv,            // Combined QKV tensor [num_tokens,
                                   // (num_heads_q+num_heads_k+num_heads_v)*head_dim]
    int64_t num_heads_q,           // Number of query heads
    int64_t num_heads_k,           // Number of key heads
    int64_t num_heads_v,           // Number of value heads
    int64_t head_dim,              // Dimension per head
    double eps,                    // Epsilon for RMS normalization
    torch::Tensor& q_weight,       // RMSNorm weights for query [head_dim]
    torch::Tensor& k_weight,       // RMSNorm weights for key [head_dim]
    torch::Tensor& cos_sin_cache,  // Cos/sin cache [max_position, head_dim]
    bool is_neox,                  // Whether RoPE is applied in Neox style
    torch::Tensor& position_ids    // Position IDs for RoPE [num_tokens]
) {
  // ---- Per-tensor checks -------------------------------------------------
  CHECK_INPUT(qkv);
  CHECK_INPUT(position_ids);
  CHECK_INPUT(q_weight);
  CHECK_INPUT(k_weight);
  CHECK_INPUT(cos_sin_cache);
  CHECK_TYPE(position_ids, torch::kInt64);

  // ---- Rank checks -------------------------------------------------------
  TORCH_CHECK(qkv.dim() == 2,
              "QKV tensor must be 2D: [num_tokens, "
              "(num_heads_q+num_heads_k+num_heads_v)*head_dim]");
  TORCH_CHECK(position_ids.dim() == 1, "Position IDs must be 1D: [num_tokens]");
  TORCH_CHECK(q_weight.dim() == 1, "Query weights must be 1D: [head_dim]");
  TORCH_CHECK(k_weight.dim() == 1, "Key weights must be 1D: [head_dim]");
  TORCH_CHECK(cos_sin_cache.dim() == 2,
              "Cos/sin cache must be 2D: [max_position, head_dim]");

  // ---- Extent checks against head_dim ------------------------------------
  TORCH_CHECK(q_weight.size(0) == head_dim,
              "Query weights size must match head dimension");
  TORCH_CHECK(k_weight.size(0) == head_dim,
              "Key weights size must match head dimension");
  TORCH_CHECK(cos_sin_cache.size(1) == head_dim,
              "Cos/sin cache dimension must match head_dim");

  // ---- dtype agreement between qkv and both norm weights -----------------
  auto const qkv_dtype = qkv.scalar_type();
  bool const weights_share_dtype = qkv_dtype == q_weight.scalar_type() &&
                                   qkv_dtype == k_weight.scalar_type();
  TORCH_CHECK(weights_share_dtype,
              "qkv, q_weight and k_weight must have the same dtype");

  // ---- Row count / row width ---------------------------------------------
  int64_t const token_count = qkv.size(0);
  TORCH_CHECK(position_ids.size(0) == token_count,
              "Number of tokens in position_ids must match QKV");
  int64_t const head_count = num_heads_q + num_heads_k + num_heads_v;
  int64_t const expected_row_width = head_count * head_dim;
  TORCH_CHECK(
      qkv.size(1) == expected_row_width,
      "QKV tensor size must match total number of heads and head dimension");

  auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());

  // The launcher takes plain ints/floats; narrow once, outside the dispatch.
  int const n_tokens = static_cast<int>(token_count);
  int const n_q = static_cast<int>(num_heads_q);
  int const n_k = static_cast<int>(num_heads_k);
  int const n_v = static_cast<int>(num_heads_v);
  int const dim = static_cast<int>(head_dim);
  float const eps_f = static_cast<float>(eps);
  bool const interleave = !is_neox;

  // Outer dispatch picks the qkv element type, inner dispatch the cache
  // element type; each macro exposes its choice as `scalar_t` in the lambda.
  VLLM_DISPATCH_HALF_TYPES(qkv.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
    using in_t = scalar_t;
    VLLM_DISPATCH_FLOATING_TYPES(
        cos_sin_cache.scalar_type(), "fused_qk_norm_rope_kernel", [&] {
          using cache_t = scalar_t;
          tensorrt_llm::kernels::launchFusedQKNormRope<in_t, cache_t>(
              qkv.data_ptr(), n_tokens, n_q, n_k, n_v, dim, eps_f,
              q_weight.data_ptr(), k_weight.data_ptr(),
              cos_sin_cache.data_ptr(), interleave,
              static_cast<int64_t const*>(position_ids.data_ptr()), stream);
        });
  });
}
\ No newline at end of file
......@@ -8,11 +8,37 @@
#define VLLM_LAUNCH_BLOCKS_CAP 4
#endif
// compile-time estimate of max threads per SM for launch bounds.
// Compile-time estimate of max threads per SM for launch bounds.
// Families: 1024, 1536, 2048 threads/SM.
#ifndef VLLM_MAX_THREADS_PER_SM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
#define VLLM_MAX_THREADS_PER_SM 1536
#ifdef __CUDA_ARCH__
/* 1024 thr/SM: Turing (sm_75) */
#if (__CUDA_ARCH__ == 750)
#define VLLM_MAX_THREADS_PER_SM 1024
/* 1536 thr/SM: Ampere GA10x (sm_86/87), Ada (sm_89),
GB20x consumer (sm_120/121), Thor (sm_101 or sm_110) */
#elif (__CUDA_ARCH__ == 860) || (__CUDA_ARCH__ == 870) || \
(__CUDA_ARCH__ == 890) || (__CUDA_ARCH__ == 1010) || \
(__CUDA_ARCH__ == 1100) || (__CUDA_ARCH__ == 1200) || \
(__CUDA_ARCH__ == 1210)
#define VLLM_MAX_THREADS_PER_SM 1536
/* 2048 thr/SM: Volta (sm_70/72), Ampere GA100 (sm_80),
Hopper (sm_90), Blackwell (sm_100/103) */
#elif (__CUDA_ARCH__ == 700) || (__CUDA_ARCH__ == 720) || \
(__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900) || \
(__CUDA_ARCH__ == 1000) || (__CUDA_ARCH__ == 1030)
#define VLLM_MAX_THREADS_PER_SM 2048
/* Fallback: use 2048 for unknown future CCs */
#else
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#else
/* Host pass (no __CUDA_ARCH__): neutral default */
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#endif
......
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
......@@ -8,7 +10,7 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
template <typename scalar_t, int VEC_SIZE>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
......@@ -17,11 +19,21 @@ __global__ void rms_norm_kernel(
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
const scalar_t* input_row = input + blockIdx.x * input_stride;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
......@@ -32,10 +44,20 @@ __global__ void rms_norm_kernel(
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * input_stride + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
scalar_t* out_row = out + blockIdx.x * hidden_size;
auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) {
vec_n_t<scalar_t, VEC_SIZE> dst;
vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
}
v_out[i] = dst;
}
}
......@@ -135,211 +157,6 @@ fused_add_rms_norm_kernel(
}
}
/* Packed-vector helper for polynomial normalization (poly norm) on
   FP16/BF16 data.

   _f16VecPN derives from _f16Vec and reuses its storage (`data`) and its
   Converter/T1/T2 type aliases; it adds only the two operations poly norm
   needs and that the base does not provide:
     - sum_pows():          sums of x^2, x^4 and x^6 over the held elements
     - poly_norm_inplace(): overwrites each element x with
                            w2_inv_std*x + w1_inv_std2*x^2 + w0_inv_std3*x^3 + bias

   Elements are processed two at a time, so both loops step by 2 and read
   data[i] and data[i + 1]. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  // Returns a tuple (sum of x^2, sum of x^4, sum of x^6) accumulated in float
  // over all elements held by this vector.
  __device__ auto sum_pows() const {
    float acc_p2 = 0.0f;
    float acc_p4 = 0.0f;
    float acc_p6 = 0.0f;

#pragma unroll
    for (int lane = 0; lane < width; lane += 2) {
      // Widen one packed pair to float2 before doing any arithmetic.
      const float2 pair = Converter::convert(T2{data[lane], data[lane + 1]});

      const float lo_p2 = pair.x * pair.x;
      const float lo_p4 = lo_p2 * lo_p2;
      const float lo_p6 = lo_p4 * lo_p2;

      const float hi_p2 = pair.y * pair.y;
      const float hi_p4 = hi_p2 * hi_p2;
      const float hi_p6 = hi_p4 * hi_p2;

      acc_p2 += lo_p2 + hi_p2;
      acc_p4 += lo_p4 + hi_p4;
      acc_p6 += lo_p6 + hi_p6;
    }

    return std::make_tuple(acc_p2, acc_p4, acc_p6);
  }

  // Replaces every element x by
  //   w2_inv_std * x + w1_inv_std2 * x^2 + w0_inv_std3 * x^3 + bias
  // computed in float and converted back to the packed type.
  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3, const float bias) {
#pragma unroll
    for (int lane = 0; lane < width; lane += 2) {
      float2 pair = Converter::convert(T2{data[lane], data[lane + 1]});

      const float lo_sq = pair.x * pair.x;
      const float lo_cu = lo_sq * pair.x;
      pair.x =
          w2_inv_std * pair.x + w1_inv_std2 * lo_sq + w0_inv_std3 * lo_cu + bias;

      const float hi_sq = pair.y * pair.y;
      const float hi_cu = hi_sq * pair.y;
      pair.y =
          w2_inv_std * pair.y + w1_inv_std2 * hi_sq + w0_inv_std3 * hi_cu + bias;

      // Narrow back and store the pair in place.
      auto packed = Converter::convert(pair);
      data[lane] = packed.x;
      data[lane + 1] = packed.y;
    }
  }
};
// Vectorized poly-norm kernel. This overload participates only when
// `width > 0` and `_typeConvert<scalar_t>::exists` is true (see the
// enable_if return type).
//
// Each block handles the row selected by blockIdx.x; threads stride over the
// row in units of `width` elements. Only `hidden_size / width` full vectors
// are visited, so any remainder elements of a row would be left untouched.
//
// Two passes over the row:
//   1. accumulate sums of x^2, x^4, x^6 and reduce them across the block;
//   2. thread 0 derives three scale factors (weight[k] * rsqrt(mean + eps))
//      plus the bias, publishes them through shared memory, and every thread
//      then rewrites its vectors via poly_norm_inplace.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
// Number of whole `width`-element vectors per row.
const int vec_hidden_size = hidden_size / width;
// Per-thread partial sums of x^2, x^4 and x^6 respectively.
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
// Pass 1: each thread accumulates over its strided share of the row.
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
auto [x2, x4, x6] = temp.sum_pows();
variance += x2;
variance2 += x4;
variance3 += x6;
}
// Pack the three partial sums so a single block reduction handles all of them.
float3 thread_variances = make_float3(variance, variance2, variance3);
// Component-wise float3 addition used as the reduction operator.
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
// Block-wide reduction; blockDim.x is passed as the number of valid inputs.
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
// Scale factors and bias shared by the whole block for pass 2.
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
// Only thread 0 consumes the reduced sums and fills the shared scalars.
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
// w2 scales the x term, w1 the x^2 term, w0 the x^3 term; each is divided
// by the root of the mean of the squared term (x^2, x^4, x^6) plus epsilon.
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
// Make thread 0's shared-memory writes visible to every thread in the block.
__syncthreads();
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
// Pass 2: reload each vector, apply the polynomial in place, store to `out`.
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
out_v[id] = temp;
}
}
/* Generic (scalar) poly_norm_kernel.
   Selected when width == 0 or when scalar_t has no _typeConvert support.
   `width` is otherwise unused here; it exists so that this overload and the
   vectorized one share a template signature.

   One block processes the row selected by blockIdx.x:
     pass 1 - block-wide sums of x^2, x^4 and x^6;
     pass 2 - out = x*c1 + x^2*c2 + x^3*c3 + bias, where
              c1 = weight[2] * rsqrt(mean(x^2) + eps),
              c2 = weight[1] * rsqrt(mean(x^4) + eps),
              c3 = weight[0] * rsqrt(mean(x^6) + eps).
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Element offset of this block's row (same unsigned arithmetic as indexing
  // with blockIdx.x * hidden_size + column directly).
  const auto row_start = blockIdx.x * hidden_size;

  // ---- Pass 1: per-thread partial sums of the even powers ----------------
  float sum_p2 = 0.0f;
  float sum_p4 = 0.0f;
  float sum_p6 = 0.0f;
  for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
    const float v = (float)input[row_start + col];
    const float p2 = v * v;
    const float p4 = p2 * p2;
    const float p6 = p4 * p2;
    sum_p2 += p2;
    sum_p4 += p4;
    sum_p6 += p6;
  }

  // ---- Block reduction of the three sums at once -------------------------
  struct AddFloat3 {
    __device__ float3 operator()(const float3& lhs, const float3& rhs) const {
      return make_float3(lhs.x + rhs.x, lhs.y + rhs.y, lhs.z + rhs.z);
    }
  };
  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduce_scratch;
  const float3 partial = make_float3(sum_p2, sum_p4, sum_p6);
  const float3 total =
      BlockReduce(reduce_scratch).Reduce(partial, AddFloat3{}, blockDim.x);

  // ---- Thread 0 publishes the coefficients through shared memory ---------
  __shared__ float s_coef_x1;
  __shared__ float s_coef_x2;
  __shared__ float s_coef_x3;
  __shared__ float s_offset;
  if (threadIdx.x == 0) {
    const float w0 = (float)weight[0];
    const float w1 = (float)weight[1];
    const float w2 = (float)weight[2];
    s_offset = (float)bias[0];
    s_coef_x1 = w2 * rsqrtf(total.x / hidden_size + epsilon);
    s_coef_x2 = w1 * rsqrtf(total.y / hidden_size + epsilon);
    s_coef_x3 = w0 * rsqrtf(total.z / hidden_size + epsilon);
  }
  __syncthreads();

  // ---- Pass 2: evaluate the polynomial and store -------------------------
  for (int col = threadIdx.x; col < hidden_size; col += blockDim.x) {
    const float v = (float)input[row_start + col];
    const float v2 = v * v;
    const float v3 = v2 * v;
    out[row_start + col] =
        (scalar_t)(v * s_coef_x1 + v2 * s_coef_x2 + v3 * s_coef_x3 + s_offset);
  }
}
} // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size]
......@@ -351,18 +168,34 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
// We cannot just use `input.stride(-2)` if the tensor is not row-major.
// Instead, we use a 2d view to get the second-innermost stride.
// That way the dimensions (except the last one) can be arbitrarily permuted.
torch::Tensor input_view = input.view({-1, hidden_size});
int num_tokens = input_view.numel() / hidden_size;
int64_t input_stride = input_view.stride(-2);
// For large num_tokens, use smaller blocks to increase SM concurrency.
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
VLLM_DISPATCH_FLOATING_TYPES(
input_view.scalar_type(), "rms_norm_kernel", [&] {
const int calculated_vec_size =
std::gcd(16 / sizeof(scalar_t), hidden_size);
const int block_size =
std::min(hidden_size / calculated_vec_size, max_block_size);
dim3 block(block_size);
VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
});
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
......@@ -379,6 +212,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
......@@ -413,55 +248,11 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
// Launches vllm::poly_norm_kernel<scalar_t, width>, with scalar_t chosen by
// VLLM_DISPATCH_FLOATING_TYPES from input.scalar_type().
//
// The macro is not self-contained: it refers by name to `out`, `input`,
// `weight`, `bias`, `epsilon`, `hidden_size`, `grid`, `block` and `stream`,
// all of which must be in scope at the expansion site. `width` is used as a
// template argument, so it must be a compile-time constant.
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
// Host entry point for polynomial normalization.
//
// Writes, for every row of `input` (rows are the last dimension):
//   out = w2 * x / rms(x) + w1 * x^2 / rms(x^2) + w0 * x^3 / rms(x^3) + bias
// with (w0, w1, w2) = weight[0..2], bias = bias[0], and epsilon added under
// each square root (see poly_norm_kernel).
//
// Requirements (checked):
//   - `out` and `input` are contiguous, hold the same number of elements and
//     do not start at the same address;
//   - `weight` holds at least 3 elements and `bias` at least 1.
// dtype agreement between the four tensors is enforced by the typed
// data_ptr<scalar_t>() calls inside LAUNCH_FUSED_POLY_NORM.
//
// An empty `input` is a no-op. The kernel is enqueued on the current CUDA
// stream of `input`'s device and is not synchronized here.
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  // The kernel writes `out` at exactly the indices it reads from `input`.
  TORCH_CHECK(out.numel() == input.numel(),
              "poly_norm: out and input must have the same number of elements");
  // The kernel reads weight[0..2] and bias[0] unconditionally.
  TORCH_CHECK(weight.numel() >= 3,
              "poly_norm: weight must hold at least 3 elements");
  TORCH_CHECK(bias.numel() >= 1,
              "poly_norm: bias must hold at least 1 element");

  // Nothing to compute for an empty input. Returning here also avoids a host
  // division by zero when hidden_size == 0 and a zero-sized grid (an invalid
  // launch configuration) when there are no rows.
  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(out.data_ptr() != input.data_ptr());

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_POLY_NORM(0);
  }
}
......@@ -6,9 +6,11 @@
*/
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "quantization/w8a8/fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
......@@ -16,7 +18,7 @@
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type>
template <typename scalar_t, typename fp8_type, int VEC_SIZE>
__global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
......@@ -27,10 +29,21 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
const scalar_t* input_row = input + blockIdx.x * input_stride;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x;
}
};
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
......@@ -44,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * input_stride + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) {
vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[idx];
vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[idx];
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
float x = static_cast<float>(src1.val[j]);
float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
}
}
......@@ -174,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
input_stride, weight.data_ptr<scalar_t>(),
scale.data_ptr<float>(), epsilon, num_tokens,
hidden_size);
const int calculated_vec_size =
std::gcd(16 / sizeof(scalar_t), hidden_size);
const int block_size =
std::min(hidden_size / calculated_vec_size, max_block_size);
dim3 block(block_size);
VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
vec_size>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
input_stride, weight.data_ptr<scalar_t>(),
scale.data_ptr<float>(), epsilon, num_tokens,
hidden_size);
});
});
});
}
......@@ -215,6 +244,8 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
......@@ -240,7 +271,9 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
......
......@@ -24,6 +24,8 @@ struct SSMParamsBase {
int64_t pad_slot_id;
bool delta_softplus;
bool cache_enabled;
int block_size;
index_t A_d_stride;
index_t A_dstate_stride;
......@@ -46,8 +48,9 @@ struct SSMParamsBase {
index_t out_z_batch_stride;
index_t out_z_d_stride;
index_t ssm_states_batch_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dim_stride;
index_t ssm_states_dstate_stride;
index_t cache_indices_stride;
// Common data pointers.
void *__restrict__ A_ptr;
......@@ -66,6 +69,9 @@ struct SSMParamsBase {
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
};
......
......@@ -119,7 +119,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
......@@ -133,9 +133,18 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
typename Ktraits::state_t *ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;
typename Ktraits::state_t *ssm_states;
if (params.cache_enabled) {
// APC mode: ssm_states points to the base, we'll use absolute cache slots later
ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
dim_id * kNRows * params.ssm_states_dim_stride;
} else {
// Non-APC mode: offset by cache_index as before
ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
cache_index * params.ssm_states_batch_stride +
dim_id * kNRows * params.ssm_states_dim_stride;
}
float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
......@@ -159,7 +168,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// }
constexpr int kChunkSize = kNThreads * kNItems;
const int n_chunks = (seqlen + 2048 - 1) / 2048;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
const int* batch_cache_indices = cache_indices != nullptr ?
cache_indices + batch_id * params.cache_indices_stride : nullptr;
const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
reinterpret_cast<const int*>(params.block_idx_first_scheduled_token_ptr) : nullptr;
const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
......@@ -219,7 +243,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 ));
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
......@@ -242,7 +266,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for (int i = 0; i < kNItems; ++i) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
......@@ -250,8 +273,24 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// Initialize running total
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
scan_t running_prefix;
if (chunk > 0) {
running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
} else {
// Load initial state
if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
r * params.ssm_states_dim_stride +
state_idx * params.ssm_states_dstate_stride;
running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
} else if (has_initial_state) {
// Non-APC mode: load from current batch position
running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
} else {
// No initial state
running_prefix = make_float2(1.0, 0.0);
}
}
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
......@@ -260,8 +299,25 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// There's a syncthreads in the scan op, so we don't need to sync here.
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) {
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
// Store state at the end of each chunk when cache is enabled
if (params.cache_enabled && batch_cache_indices != nullptr) {
size_t cache_slot;
if (chunk == n_chunks - 1) {
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
} else {
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
}
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
r * params.ssm_states_dim_stride +
state_idx * params.ssm_states_dstate_stride;
ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
} else if (!params.cache_enabled && chunk == n_chunks - 1) {
// Non-APC mode: store only final state at current batch position
ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
}
}
......@@ -274,7 +330,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
__syncthreads();
......@@ -346,7 +401,9 @@ template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#ifndef USE_ROCM
if (params.seqlen <= 128) {
if (params.cache_enabled && params.block_size == 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 128) {
selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
......@@ -358,7 +415,9 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
}
#else
if (params.seqlen <= 256) {
if (params.cache_enabled && params.block_size == 1024) {
selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
} else if (params.seqlen <= 512) {
selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
......@@ -437,13 +496,17 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const std::optional<at::Tensor>& D,
const std::optional<at::Tensor>& delta_bias,
const torch::Tensor ssm_states,
bool has_z,
bool has_z,
bool delta_softplus,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t pad_slot_id) {
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
// Reset the parameters
memset(&params, 0, sizeof(params));
......@@ -477,6 +540,14 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
// Set cache parameters - cache is enabled if we have direct cache writing params
params.cache_enabled = block_idx_first_scheduled_token.has_value();
params.block_size = static_cast<int>(block_size);
// Set direct cache writing pointers
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
// All strides are in elements, not bytes.
params.A_d_stride = A.stride(0);
......@@ -504,9 +575,11 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_d_stride = out.stride(0);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
}
else{
if (!is_variable_B) {
......@@ -537,8 +610,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.out_d_stride = out.stride(1);
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
}
}
......@@ -554,7 +629,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
int64_t pad_slot_id,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
......@@ -646,7 +725,16 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(cache_indices_, batch_size);
// cache_indices can be either 1D (batch_size,) for non-APC mode
// or 2D (batch_size, max_positions) for APC mode
const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
if (is_apc_mode) {
TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
} else {
CHECK_SHAPE(cache_indices_, batch_size);
}
}
......@@ -686,7 +774,11 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
cache_indices,
has_initial_state,
varlen,
pad_slot_id
pad_slot_id,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx
);
......
......@@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
// Per-expert outputs filled in parallel
std::vector<torch::Tensor> y_list(E);
y_list.resize(E);
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
if (apply_router_weight_on_input) {
X_all = X_all.mul(expert_gates.unsqueeze(1));
}
auto Y_all = at::empty({offsets[E], H}, x_c.options());
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
c10::InferenceMode guard;
for (int64_t e = e_begin; e < e_end; ++e) {
const int64_t te = counts[e];
if (te == 0) {
y_list[e] = at::empty({0, H}, x_c.options());
continue;
}
const int64_t start = offsets[e];
auto sel_tokens =
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto gates_e =
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
if (apply_router_weight_on_input) {
x_e = x_e.mul(gates_e.unsqueeze(1));
}
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto w13_e = w13_packed.select(/*dim=*/0, e);
auto w2_e = w2_packed.select(/*dim=*/0, e);
......@@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
// W2
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
if (!apply_router_weight_on_input) {
y = y.mul(gates_e.unsqueeze(1));
}
// Store per-expert result
y_list[e] = y;
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
}
});
// Concatenate all expert outputs to match expert_tokens order
auto Y_all = at::cat(y_list, /*dim=*/0);
if (!apply_router_weight_on_input) {
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
}
auto out = at::zeros({T, H}, x.options());
out =
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
......
......@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
#endif
}
// Scoring function selector applied to raw router scores before the bias is
// added. The integer values are part of the interface: callers pass them as a
// plain int (`scoring_func`) and compare against these enumerators.
enum ScoringFunc {
SCORING_NONE = 0, // no activation function: scores are used as-is
SCORING_SIGMOID = 1 // apply sigmoid to each score
};
// Sigmoid formulation from TensorRT-LLM.
// Evaluates the logistic function through the identity
//   sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
// so the only transcendental call is a single-precision tanhf.
__device__ inline float sigmoid_accurate(float x) {
return 0.5f * tanhf(0.5f * x) + 0.5f;
}
template <typename T>
__device__ void topk_with_k2(T* output, T const* input,
__device__ inline T apply_sigmoid(T val) {
float f = cuda_cast<float, T>(val);
return cuda_cast<T, float>(sigmoid_accurate(f));
}
template <typename T>
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group) {
int const num_experts_per_group,
int const scoring_func) {
// Get the top2 per thread
T largest = neg_inf<T>();
T second_largest = neg_inf<T>();
......@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i];
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
value = value + bias[i];
if (value > largest) {
second_largest = largest;
largest = value;
......@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
}
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
largest = input[i];
T value = input[i];
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
value = value + bias[i];
largest = value;
}
}
......@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
}
template <typename T>
__global__ void topk_with_k2_kernel(T* output, T* input,
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
int64_t const num_experts_per_group,
int const scoring_func) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
if (case_id < num_cases) {
input += case_id * num_experts_per_group;
// bias is per expert group, offset to current group
int32_t group_id = case_id % n_group;
T const* group_bias = bias + group_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block();
......@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
topk_with_k2(output, input, group_bias, tile, lane_id,
num_experts_per_group, scoring_func);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
......@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor) {
double routed_scaling_factor, int scoring_func) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
scores_with_bias += case_id * num_experts;
scores += case_id * num_experts;
group_scores += case_id * n_group;
topk_values += case_id * topk;
......@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = (i < num_experts_per_group) &&
is_finite(scores_with_bias[offset + i])
? scores_with_bias[offset + i]
: neg_inf<T>();
T candidates = neg_inf<T>();
if (i < num_experts_per_group) {
// Apply scoring function (if any) and add bias
T input = scores[offset + i];
if (is_finite(input)) {
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
: input;
candidates = score + bias[offset + i];
}
}
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
......@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value =
i < topk
? scores[s_topk_idx[i]]
: cuda_cast<T, float>(0.0f); // Load the valid value of expert
T value = cuda_cast<T, float>(0.0f);
if (i < topk) {
// Load the score value (without bias) for normalization
T input = scores[s_topk_idx[i]];
value =
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
s_topk_value[i] = value;
}
topk_sum +=
......@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
topk_indices[i] = s_topk_idx[i];
topk_values[i] = cuda_cast<T, float>(value);
topk_values[i] = value;
}
} else {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
topk_indices[i] = i;
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
topk_values[i] = 1.0f / topk;
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
......@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
}
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
IdxT* topk_indices, T* scores_with_bias,
int64_t const num_tokens, int64_t const num_experts,
int64_t const n_group, int64_t const topk_group,
int64_t const topk, bool const renormalize,
double const routed_scaling_factor, bool enable_pdl = false,
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk,
bool const renormalize, double const routed_scaling_factor,
int const scoring_func, bool enable_pdl = false,
cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
......@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts / n_group,
scoring_func);
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
......@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
topk_values, topk_indices, bias, num_tokens, n_group,
topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor, scoring_func);
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>( \
T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
T * scores_with_bias, int64_t const num_tokens, \
int64_t const num_experts, int64_t const n_group, \
int64_t const topk_group, int64_t const topk, bool const renormalize, \
double const routed_scaling_factor, bool enable_pdl, \
cudaStream_t const stream);
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
......@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
} // namespace vllm
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
double routed_scaling_factor) {
auto data_type = scores_with_bias.scalar_type();
auto input_size = scores_with_bias.sizes();
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func = 0) {
auto data_type = scores.scalar_type();
auto input_size = scores.sizes();
int64_t num_tokens = input_size[0];
int64_t num_experts = input_size[1];
TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
TORCH_CHECK(num_experts % n_group == 0,
"num_experts should be divisible by n_group");
TORCH_CHECK(n_group <= 32,
"n_group should be smaller than or equal to 32 for now");
TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
scoring_func == vllm::moe::SCORING_SIGMOID,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
torch::Tensor group_scores = torch::empty(
{num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
// Always output float32 for topk_values (eliminates Python-side conversion)
torch::Tensor topk_values = torch::empty(
{num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
{num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor topk_indices = torch::empty(
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
switch (data_type) {
case torch::kFloat16:
......@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
vllm::moe::invokeNoAuxTc<half, int32_t>(
reinterpret_cast<half*>(scores.mutable_data_ptr()),
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, false, stream);
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
case torch::kFloat32:
// Handle Float32
......@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, false, stream);
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
case torch::kBFloat16:
// Handle BFloat16
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
num_tokens, num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, false, stream);
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
default:
// Handle other data types
......
......@@ -17,25 +17,30 @@ FILE_HEAD = """
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
......@@ -58,11 +63,12 @@ def generate_new_kernels():
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
......
......@@ -8,12 +8,77 @@
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm {
namespace moe {
namespace batched_moe_align_block_size {
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;
// Builds the block-aligned token layout for a batched MoE input.
//
// Launch layout: a single thread block of `num_threads` threads; thread `t`
// owns batch `t`, so `num_batches` must not exceed the block size.
//
// Parameters:
//   num_batches          - number of batches (entries in batch_num_tokens).
//   max_tokens_per_batch - row pitch used to turn (batch, token) into a flat
//                          token id.
//   block_size           - alignment unit; each batch is padded up to a
//                          multiple of it.
//   batch_num_tokens     - [num_batches] token count per batch.
//   sorted_ids           - [ceil(max_tokens_per_batch / block_size) *
//                          num_batches * block_size] output; valid entries are
//                          flat token ids, all others hold the sentinel
//                          num_batches * max_tokens_per_batch.
//   block_ids            - [sorted_ids_size / block_size] output; the batch id
//                          owning each block, or -1 for unused blocks.
//   num_tokens_post_pad  - single-element output; total padded token count.
__global__ void batched_moe_align_block_size_kernel(
    int32_t const num_batches, int32_t const max_tokens_per_batch,
    int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
    int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
    int32_t* __restrict__ num_tokens_post_pad) {
  // TODO(varun): This is a naive implementation. Could be optimized.

  // With no batches there is no "last batch" thread to publish the total, so
  // write it explicitly. The condition is uniform across the block, which
  // makes the early return safe with respect to the barriers below.
  if (num_batches <= 0) {
    if (threadIdx.x == 0) {
      *num_tokens_post_pad = 0;
    }
    return;
  }

  size_t const batch_id = threadIdx.x;
  size_t const stride = blockDim.x * gridDim.x;
  int32_t const num_blocks_per_batch =
      CEILDIV(max_tokens_per_batch, block_size);
  int32_t const sorted_ids_size =
      num_blocks_per_batch * num_batches * block_size;
  int32_t const block_ids_size = sorted_ids_size / block_size;
  int32_t const SENTINEL =
      num_batches * max_tokens_per_batch;  // To denote invalid entries.

  // Initialize sorted_ids with the sentinel.
  for (size_t i = threadIdx.x; i < static_cast<size_t>(sorted_ids_size);
       i += stride) {
    sorted_ids[i] = SENTINEL;
  }
  // Initialize block_ids with -1.
  for (size_t i = threadIdx.x; i < static_cast<size_t>(block_ids_size);
       i += stride) {
    block_ids[i] = -1;
  }

  // Threads beyond num_batches contribute 0 tokens to the scan.
  int32_t b_num_tokens = 0;
  if (batch_id < static_cast<size_t>(num_batches)) {
    b_num_tokens = batch_num_tokens[batch_id];
  }
  int32_t const ceil_b_num_tokens =
      CEILDIV(b_num_tokens, block_size) * block_size;

  // Exclusive prefix sum over the padded per-batch token counts: cumsum_val is
  // the first slot of this thread's batch in sorted_ids.
  using BlockScan = cub::BlockScan<int32_t, num_threads>;
  __shared__ typename BlockScan::TempStorage temp_storage;
  int cumsum_val;
  BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
  // Also orders the sentinel initialization above before the scattered writes
  // below, which touch slots initialized by other threads.
  __syncthreads();

  bool const is_last_batch =
      batch_id == static_cast<size_t>(num_batches - 1);
  if (is_last_batch) {
    *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
  }

  if (batch_id < static_cast<size_t>(num_batches)) {
    int32_t const batch_offset =
        static_cast<int32_t>(batch_id) * max_tokens_per_batch;
    // Signed counters: a (bogus) negative count yields zero iterations instead
    // of wrapping to a huge unsigned bound.
    for (int32_t i = 0; i < b_num_tokens; ++i) {
      sorted_ids[cumsum_val + i] = batch_offset + i;
    }

    int32_t const block_start = cumsum_val / block_size;
    int32_t const num_blocks = ceil_b_num_tokens / block_size;
    for (int32_t i = 0; i < num_blocks; ++i) {
      block_ids[block_start + i] = static_cast<int32_t>(batch_id);
    }
  }
}
} // namespace batched_moe_align_block_size
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
......@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
});
}
// Host entry point for batched_moe_align_block_size_kernel.
//
// Parameters:
//   max_tokens_per_batch - row pitch of the batched token layout.
//   block_size           - alignment unit; must be positive.
//   batch_num_tokens     - [B] int32 token count per batch.
//   sorted_ids           - [ceil(max_tokens_per_batch / block_size) * B *
//                          block_size] int32 output.
//   batch_ids            - [sorted_ids_size / block_size] int32 output.
//   num_tokens_post_pad  - [1] int32 output.
//
// Raises (via TORCH_CHECK) when an argument is out of range or an output
// tensor has an unexpected size. The kernel is enqueued on the current CUDA
// stream and is not synchronized here.
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                  int64_t block_size,
                                  torch::Tensor const& batch_num_tokens,
                                  torch::Tensor sorted_ids,
                                  torch::Tensor batch_ids,
                                  torch::Tensor num_tokens_post_pad) {
  namespace batched_kernel = vllm::moe::batched_moe_align_block_size;

  // block_size is used as a divisor below and inside the kernel.
  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
  TORCH_CHECK(max_tokens_per_batch >= 0,
              "max_tokens_per_batch should be non-negative. ");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int32_t const B = batch_num_tokens.size(0);
  int32_t const num_blocks_per_batch =
      round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
  int32_t const num_blocks = num_blocks_per_batch * B;
  int64_t const sorted_ids_size = num_blocks * block_size;

  TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
  TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
  TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
  // One thread per batch inside a single block.
  TORCH_CHECK(B <= batched_kernel::num_threads);

  batched_kernel::batched_moe_align_block_size_kernel<<<
      batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
      B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
      sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
      num_tokens_post_pad.data_ptr<int32_t>());
}
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
......
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {

// Flattens a (row, col) coordinate of a row-major 2D array that has
// `total_col` columns into a linear element offset.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  int32_t const row_base = row * total_col;
  return row_base + col;
}

}  // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
//
// Per-LoRA variant of the MoE block-alignment kernel. Each thread block
// handles the LoRA slot selected by blockIdx.x and produces, for that LoRA
// only, the expert-sorted and block_size-padded token ordering.
//
// Dynamic shared memory layout (sized by the host launcher):
//   int32_t      cumsum[num_experts + 1]
//   token_cnts_t tokens_cnts[(blockDim.x + 1) * num_experts]
//
// Parameters:
//   topk_ids              - flat array of `numel` expert ids.
//   token_lora_mapping    - LoRA id per token, indexed by i / topk_num.
//   block_size            - padding unit for each expert's token range.
//   num_experts           - number of experts (columns of tokens_cnts).
//   max_loras             - not referenced in the kernel body.
//   numel                 - number of entries in topk_ids; also used as the
//                           "empty" sentinel in sorted_token_ids.
//   max_num_tokens_padded - per-LoRA row length of sorted_token_ids.
//   max_num_m_blocks      - per-LoRA row length of expert_ids.
//   sorted_token_ids      - output, row `lora_id` is written.
//   expert_ids            - output, row `lora_id` is written.
//   topk_num              - entries of topk_ids per token.
//   total_tokens_post_pad - output, element `lora_id` receives the padded
//                           token count.
//   adapter_enabled       - per-LoRA flag; 0 disables processing.
//   lora_ids              - LoRA id per block index; -1 marks an unused slot.
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
int32_t* lora_ids) {
// Each thread owns a contiguous shard of topk_ids of this many entries.
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
// The condition depends only on blockIdx.x, so the whole block exits
// together and no thread is left waiting at a later __syncthreads().
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
}
// Initialize expert_ids with -1
for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
expert_ids[lora_id * max_num_m_blocks + it] = -1;
}
// Initialize total_tokens_post_pad with 0
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
}
// Zero this thread's row of counters (row threadIdx.x + 1; row 0 is reserved
// as the starting point of the prefix sum below).
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
// Count, per expert, the entries in this thread's shard whose token belongs
// to this LoRA (mask is 1 for a match, 0 otherwise).
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int mask = token_lora_mapping[i / topk_num] == lora_id;
int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
tokens_cnts[idx] += mask;
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
// After this pass row i holds the number of matching entries found by
// threads 0..i-1, and row blockDim.x holds the total for the expert.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
// Each expert's total is rounded up to a multiple of block_size, so
// cumsum[e] is the padded start offset of expert e.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
threadIdx.x;
}
}
// Scatter pass: walk the same shard again and place each matching entry at
// its final position.
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
// The slot was initialized to numel, so adding (i - numel) turns it into i
// for a matching entry; a non-matching entry adds 0 and leaves the slot
// untouched. The running counter only advances for matching entries.
atomicAdd(
&sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
(i - numel) * mask);
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
}
}
// Host entry point for moe_lora_align_sum_kernel.
//
// Launches one thread block per LoRA slot (`max_loras` blocks) on the current
// CUDA stream. Each block uses max(num_experts, 128) threads and
// (num_thread + 1) * num_experts + (num_experts + 1) int32 words of dynamic
// shared memory.
//
// Parameters:
//   topk_ids              - 2D integral tensor of expert ids; size(1) is the
//                           number of experts selected per token.
//   token_lora_mapping    - int32 LoRA id per token.
//   num_experts           - number of experts.
//   block_size            - padding unit; must be positive.
//   max_loras             - number of LoRA slots (grid size).
//   max_num_tokens_padded - per-LoRA row length of sorted_token_ids.
//   max_num_m_blocks      - per-LoRA row length of expert_ids.
//   sorted_token_ids      - int32 output.
//   expert_ids            - int32 output.
//   num_tokens_post_pad   - int32 output, one element per LoRA.
//   adapter_enabled       - int32 per-LoRA enable flags.
//   lora_ids              - int32 LoRA id per slot.
//
// Raises (via TORCH_CHECK / AT_CUDA_CHECK) when block_size is not positive,
// the thread count exceeds 1024, the shared-memory requirement exceeds the
// device limit, or a CUDA runtime call fails.
void moe_lora_align_block_size(
    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
    int64_t num_experts, int64_t block_size, int64_t max_loras,
    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
    torch::Tensor lora_ids) {
  const int topk_num = topk_ids.size(1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");

  // Check the query result: on failure device_max_shared_mem would otherwise
  // be compared while still uninitialized.
  int device_max_shared_mem = 0;
  auto dev = topk_ids.get_device();
  AT_CUDA_CHECK(cudaDeviceGetAttribute(
      &device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // At least 128 threads, and at least one thread per expert because the
  // kernel assigns expert `e` to thread `e` in its reduction phases.
  const int32_t num_thread = max((int32_t)num_experts, 128);

  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than or equal to 1024, "
              "and fallback is not implemented yet.");
  // (num_thread + 1) rows of per-expert counters plus the cumsum array.
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  TORCH_CHECK(shared_mem <= device_max_shared_mem,
              "Shared memory usage exceeds device limit, and global memory "
              "fallback is not implemented yet.");

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        // Opt in to the requested amount of dynamic shared memory.
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>(),
            adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
      });
}
\ No newline at end of file
......@@ -4,7 +4,7 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
torch::Tensor& gating_output, bool renormalize);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
......@@ -12,6 +12,21 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
// Batched-case counterpart of moe_align_block_size: aligns the number of
// tokens to be processed by each expert so that it is divisible by
// block_size. sorted_ids, expert_ids and num_tokens_post_pad are written by
// the op (they are registered as mutable tensors in the op schema);
// expert_num_tokens is taken by const reference.
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& expert_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
// Aligns the number of tokens to be processed by each expert so that it is
// divisible by block_size, taking a token -> LoRA mapping as extra input.
// sorted_token_ids, expert_ids, num_tokens_post_pad, adapter_enabled and
// lora_ids are registered as mutable tensors in the op schema.
// NOTE(review): the per-LoRA layout of the outputs (presumably one row of
// max_num_tokens_padded / max_num_m_blocks entries per LoRA) is inferred from
// the parameter names - confirm against the kernel.
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
......@@ -24,9 +39,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t BLOCK_SIZE_K, int64_t bit);
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
double routed_scaling_factor);
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func);
#endif
bool moe_permute_unpermute_supported();
......
......@@ -16,12 +16,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
......@@ -36,16 +47,27 @@ template <
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
float data[N];
struct alignas(Alignment) AlignedArray {
T data[N];
};
// Converts a single input element to float.
//
// Supported element types: float (returned unchanged), __nv_bfloat16 and
// __half (widened with the corresponding conversion intrinsic). Any other
// type is rejected at compile time; without that guard the function would
// fall off the end without returning a value.
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, __nv_bfloat16> ||
                    std::is_same_v<T, __half>,
                "toFloat only supports float, __nv_bfloat16 and __half");
  if constexpr (std::is_same_v<T, float>) {
    return value;
  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
    return __bfloat162float(value);
  } else {
    // Only __half can reach this branch thanks to the static_assert above.
    return __half2float(value);
  }
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
......@@ -66,7 +88,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
const float val = toFloat(input[idx]);
threadData = max(val, threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
......@@ -81,7 +104,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
const float val = toFloat(input[idx]);
threadData += expf(val - float_max);
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
......@@ -95,8 +119,9 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
const float val = toFloat(input[idx]);
const float softmax_val = expf(val - float_max) * normalizing_factor;
output[idx] = softmax_val;
}
}
......@@ -110,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const int num_experts,
const int k,
const int start_expert,
const int end_expert)
const int end_expert,
const bool renormalize)
{
using cub_kvp = cub::KeyValuePair<int, float>;
......@@ -125,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
......@@ -163,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK(
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row;
if (renormalize) {
selected_sum += result_kvp.value;
}
}
__syncthreads();
}
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (threadIdx.x == 0) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * block_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
// ====================== TopK softmax things ===============================
......@@ -184,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize)
{
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
......@@ -236,27 +286,71 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
{
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
......@@ -310,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
......@@ -363,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
......@@ -380,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
}
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (thread_group_idx == 0)
{
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
......@@ -397,20 +508,21 @@ struct TopkConstants
};
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
cudaStream_t stream)
{
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize);
}
#ifndef USE_ROCM
......@@ -418,26 +530,26 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template <typename IndType>
template <typename IndType, typename InputType>
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
const InputType* gating_output,
float* topk_weights,
IndType* topk_indices,
int* token_expert_indices,
......@@ -445,11 +557,15 @@ void topkGatingSoftmaxKernelLauncher(
const int num_tokens,
const int num_experts,
const int topk,
const bool renormalize,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
#endif
switch (num_experts) {
case 1:
......@@ -506,11 +622,11 @@ void topkGatingSoftmaxKernelLauncher(
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
static constexpr int TPB = 256;
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
num_experts, topk, 0, num_experts);
num_experts, topk, 0, num_experts, renormalize);
}
}
}
......@@ -518,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher(
} // namespace moe
} // namespace vllm
// Selects the index element type from topk_indices' dtype (Int, UInt32 or
// Long) and forwards to topkGatingSoftmaxKernelLauncher. ComputeType is the
// element type the gating_output buffer is reinterpreted as; topk_weights and
// softmax_workspace are accessed as float, token_expert_indices as int.
// Any other index dtype fails the TORCH_CHECK in the default case.
template<typename ComputeType>
void dispatch_topk_softmax_launch(
    torch::Tensor& gating_output,
    torch::Tensor& topk_weights,
    torch::Tensor& topk_indices,
    torch::Tensor& token_expert_indices,
    torch::Tensor& softmax_workspace,
    int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
{
    // Shared launch path; the index element type is deduced from the pointer
    // handed in by the dtype switch below.
    auto launch_with_indices = [&](auto* indices_out) {
        using IndexT = std::remove_pointer_t<decltype(indices_out)>;
        vllm::moe::topkGatingSoftmaxKernelLauncher<IndexT, ComputeType>(
            reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
            topk_weights.data_ptr<float>(),
            indices_out,
            token_expert_indices.data_ptr<int>(),
            softmax_workspace.data_ptr<float>(),
            num_tokens, num_experts, topk, renormalize, stream);
    };

    switch (topk_indices.scalar_type()) {
        case at::ScalarType::Int:
            launch_with_indices(topk_indices.data_ptr<int>());
            break;
        case at::ScalarType::UInt32:
            launch_with_indices(topk_indices.data_ptr<uint32_t>());
            break;
        default:
            TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
            launch_with_indices(topk_indices.data_ptr<int64_t>());
            break;
    }
}
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output) // [num_tokens, num_experts]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize)
{
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
......@@ -534,45 +689,19 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
if(topk_indices.scalar_type() == at::ScalarType::Int)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
}
}
......@@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results
......@@ -22,6 +22,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size, but for the batched case.
m.def(
"batched_moe_align_block_size(int max_tokens_per_batch,"
" int block_size, Tensor expert_num_tokens,"
" Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" int max_num_tokens_padded, "
" int max_num_m_blocks, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
......@@ -80,9 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply grouped topk routing to select experts.
m.def(
"grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int "
"grouped_topk(Tensor scores, int n_group, int "
"topk_group, int topk, bool renormalize, float "
"routed_scaling_factor) -> (Tensor, Tensor)");
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
#endif
}
......
......@@ -92,14 +92,25 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& bias, double epsilon);
void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
int64_t num_heads_k, int64_t num_heads_v,
int64_t head_dim, double eps, torch::Tensor& q_weight,
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
bool is_neox, torch::Tensor& position_ids);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
const torch::Tensor& seq_lens, torch::Tensor& indices,
int64_t numRows, int64_t stride0, int64_t stride1);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
......@@ -133,12 +144,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
// void persistent_masked_m_silu_mul_quant(
// const at::Tensor& input, // (E, T, 2*H)
// const at::Tensor& counts, // (E)
// at::Tensor& y_q, // (E, T, H) [OUT]
// at::Tensor& y_s, // (E, T, H//group_size) [OUT]
// bool use_ue8m0);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
......@@ -304,7 +315,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);
bool use_exllama, bool use_v2_format, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
......@@ -318,17 +329,19 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
const std::optional<torch::Tensor>& D_,
const std::optional<torch::Tensor>& z_,
const std::optional<torch::Tensor>& delta_bias_,
bool delta_softplus,
const std::optional<torch::Tensor>& query_start_loc,
const std::optional<torch::Tensor>& cache_indices,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);
void selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& C,
const std::optional<torch::Tensor>& D_,
const std::optional<torch::Tensor>& z_,
const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
const std::optional<torch::Tensor>& query_start_loc,
const std::optional<torch::Tensor>& cache_indices,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
......
......@@ -7,7 +7,7 @@
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#include "quantization/w8a8/fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h>
......@@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel(
}
// SiLU activation: x / (1 + exp(-x)).
// Uses the fast-division intrinsic __fdividef, trading some accuracy for
// speed.
__device__ __forceinline__ float silu(float x) {
  return __fdividef(x, (1.f + expf(-x)));
}
// Applies silu() independently to both components of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  const float first = silu(x.x);
  const float second = silu(x.y);
  return make_float2(first, second);
}
// Applies silu() to both components of a float2 and packs the results into a
// bfloat16 pair using round-to-nearest conversions. The CUDA and ROCm builds
// use different packing intrinsics.
__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) {
  const float act_x = silu(x.x);
  const float act_y = silu(x.y);
#ifndef USE_ROCM
  return make_bfloat162(__float2bfloat16_rn(act_x),
                        __float2bfloat16_rn(act_y));
#else
  return __float22bfloat162_rn(make_float2(act_x, act_y));
#endif
}
#ifndef USE_ROCM
__device__ __forceinline__ float warp_max(float v) {
static constexpr unsigned FULL_MASK = 0xffffffffu;
......@@ -223,224 +232,337 @@ constexpr __nv_bfloat16 get_fp8_min() {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
}
}
#ifndef USE_ROCM
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
int NUM_STAGES = 3>
// Warp-cooperative search over input[0..n).
//
// Each lane starts at its own idx and tests `idx < n && input[idx] <= val`;
// the lanes vote with a full-warp ballot (mask 0xffffffff, so all 32 lanes of
// the warp must call this together). While every lane's predicate holds, all
// lanes advance by 32 elements. As soon as at least one lane's predicate
// fails, every lane returns
//     (elements skipped so far) + (highest lane whose predicate held),
// which is (elements skipped so far) - 1 when no lane's predicate held
// (since __clz(0) == 32).
template <typename Idx_t>
__device__ __forceinline__ int warp_expert_search(
    int idx, int n, const Idx_t* __restrict__ input, Idx_t val) {
  int skipped = 0;
  while (true) {
    const bool keep_going = (idx < n) && (input[idx] <= val);
    const unsigned votes = __ballot_sync(0xffffffff, keep_going);
    if (votes != 0xffffffffu) {
      // Index of the most significant set bit of the ballot.
      const int highest_lane = 31 - __clz(votes);
      return skipped + highest_lane;
    }
    // Every lane passed: slide the whole warp window forward.
    idx += 32;
    skipped += 32;
  }
}
// Computes the half-open token range [n_tokens_lower, n_tokens_upper) that
// worker `worker_id` handles when n_tokens tokens are split across
// num_parallel_tokens workers.
//
//  - Fewer tokens than workers, and this worker's id is a valid token index:
//    the worker gets exactly the single token `worker_id`.
//  - Otherwise the tokens are split into contiguous chunks: with
//    chunk_size = n_tokens / num_parallel_tokens and
//    residual   = n_tokens % num_parallel_tokens,
//    the first `residual` workers get chunk_size + 1 tokens and the remaining
//    ones chunk_size tokens. Both bounds are clamped to n_tokens.
//
// Both outputs are written on every path. (The former early
// `if (worker_id >= num_parallel_tokens) return;` inside the first branch
// could never fire, because that branch already requires
// worker_id < n_tokens < num_parallel_tokens; it has been removed.)
template <int num_parallel_tokens>
__device__ __forceinline__ void token_bounds(int32_t n_tokens,
                                             int32_t worker_id,
                                             int32_t& n_tokens_lower,
                                             int32_t& n_tokens_upper) {
  if (n_tokens < num_parallel_tokens && worker_id < n_tokens) {
    n_tokens_lower = worker_id;
    n_tokens_upper = worker_id + 1;
  } else {
    int32_t chunk_size = n_tokens / num_parallel_tokens;
    int32_t residual = n_tokens - chunk_size * num_parallel_tokens;
    // Start offset of worker `id`; workers below `residual` each carry one
    // extra token.
    auto calc_id = [&](int32_t id) {
      if (id < residual)
        return min(n_tokens, id * (chunk_size + 1));
      else
        return min(n_tokens, id * chunk_size + residual);
    };
    n_tokens_lower = calc_id(worker_id);
    n_tokens_upper = calc_id(worker_id + 1);
  }
}
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0,
int GROUP_SIZE = 128, int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ counts,
scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
// sizes
int H, int G,
Idx_t E, Idx_t T, Idx_t H,
// strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) {
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
#ifndef USE_ROCM
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8;
static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE;
// We split the shared memory in half, corresponding to gate and up matrices:
// [...gate_i, ...up_i] where 0 <= i < stages.
static constexpr int32_t S_NUM_128 =
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4;
static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES;
const int32_t tid = threadIdx.x;
const int32_t warp_id = tid / WARP_SIZE;
const int32_t lane_id = tid % WARP_SIZE;
extern __shared__ __align__(16) __int128_t smem_128[];
auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
int* s_expert_offsets =
reinterpret_cast<int*>(smem_128 + (SMEM_SIZE_BYTES_Y / 16));
// block handles one (expert e, group g)
int32_t pid = blockIdx.x;
int32_t e = pid / G;
int32_t g = pid % G;
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
int tid = threadIdx.x;
int warp_id = tid >> 5;
int lane_id = tid & 0x1f;
int running_sum{};
if (!warp_id) {
for (int i = 0; i < E; i += WARP_SIZE) {
bool valid = (i + threadIdx.x) < E;
int value =
(valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) +
(!lane_id ? running_sum : 0);
for (int offset = 1; offset < 32; offset *= 2) {
int n = __shfl_up_sync(0xFFFFFFFFu, value, offset);
if (lane_id >= offset) value += n;
}
const int32_t n_tokens = counts[e * stride_counts_e];
if (valid) {
s_expert_offsets[i + threadIdx.x + 1] = value;
}
if (!n_tokens) {
return; // Exit ASAP.
running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1);
}
if (!lane_id) {
s_expert_offsets[0] = 0;
}
}
const Idx_t stride_i_t_128 = stride_i_t / 8u;
__syncthreads();
int32_t n_tokens_lower, n_tokens_upper;
int32_t total_tokens = s_expert_offsets[E];
const int warp_position_yq = warp_id * (H / NUM_WARPS);
const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS));
// A single block will handle tokens_per_block tokens.
// Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
// Specialize this, but can be likely fused.
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
return;
}
n_tokens_lower = blockIdx.y;
n_tokens_upper = blockIdx.y + 1;
} else {
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
auto calc_id = [&](int32_t id) {
if (id < residual) {
return min(n_tokens, id * (chunk_size + 1));
} else {
return min(n_tokens, id * chunk_size + residual);
}
};
n_tokens_lower = calc_id(blockIdx.y);
n_tokens_upper = calc_id(blockIdx.y + 1);
}
if (n_tokens_lower >= n_tokens_upper) {
// Each warp will get space to store its hidden dim for gate and up.
__int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES);
__int128_t* smem_load_ptr = s_hidden_load + lane_id;
const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max);
int32_t compute_pipeline_offset_64 = 0;
int32_t load_stage_offset{};
const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f);
__int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) +
warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) +
lane_id;
__int64_t* s_gate64_ptr = smem_compute_ptr;
__int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4;
int tokens_lower, tokens_upper;
token_bounds<BLOCK_COUNT>(total_tokens, blockIdx.x, tokens_lower,
tokens_upper);
Idx_t expert_id{}, expert_offset{}, next_expert_offset{};
int token_id = tokens_lower;
int32_t t_load{};
if (token_id < tokens_upper) {
expert_id = warp_expert_search<int>(lane_id, E, s_expert_offsets, token_id);
expert_offset = s_expert_offsets[expert_id];
next_expert_offset = s_expert_offsets[expert_id + 1];
} else {
// This thread block has no work to do.
return;
}
// We do calculations here, using constexpr wherever possible.
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
const Idx_t base_yq =
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
stride_i_t_128 * n_tokens_lower;
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
auto y_s_ptr =
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
stride_yq_t * n_tokens_lower + 4 * lane_id;
int32_t t_load = n_tokens_lower, load_stage_id = 0;
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
int32_t stage_offset{};
static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
static constexpr int32_t LOAD_STAGE_MOD =
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
// Two halves of all threads in a block conduct global loads for gate and up,
// respectively.
int t_load_bound = H / (GROUP_SIZE * NUM_WARPS);
Idx_t base_i = ((expert_id * stride_i_e) / 8) +
(token_id - expert_offset) * stride_i_t / 8;
const Idx_t gate_warp_offset =
warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111);
const __int128_t* input_128_ptr =
reinterpret_cast<const __int128_t*>(_input) + gate_warp_offset +
((lane_id < 16) ? 0 : ((H * stride_i_h) / 8));
__int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
auto token_offset = token_id - expert_offset;
auto load_and_advance_y_pred = [&] {
if (t_load < n_tokens_upper) {
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
if (t_load < t_load_bound) {
// Here we are simply continuing to load data
// from the current token.
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
stage_offset += LOAD_STAGE_SIZE;
stage_offset %= LOAD_STAGE_MOD;
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;
if (tid < HALF_THREAD_COUNT) {
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
gate_128_ptr += stride_i_t_128;
cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load;
} else if (token_id + 1 < tokens_upper) {
// We loaded everything from the current token, let's move on
// to the next one, and we checked that we have more tokens to load.
++token_id;
t_load = 0;
if (token_id >= next_expert_offset) {
// We need to find the next expert.
do {
// This is a loop because it's possible
// that some experts are assigned 0 tokens.
// NOTE: We are guaranteed that there's at least
// one more token left so we don't have to check for
// expert_id bounds.
++expert_id;
// This skips 1 memory read.
expert_offset = next_expert_offset;
next_expert_offset = s_expert_offsets[expert_id + 1];
} while (next_expert_offset == expert_offset);
base_i = expert_id * (stride_i_e / 8);
token_offset = 0;
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
} else {
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
up_128_ptr += stride_i_t_128;
// We remain within the same expert, so just
// move by H/4 __int128_t (2 * H/8).
base_i += stride_yq_t / 4;
token_offset++;
}
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;
cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load;
++load_stage_id;
}
// We fence even if there is nothing to load to simplify pipelining.
cp_async_fence();
};
// We need to warm-up the pipeline.
#pragma unroll
for (int i = 0; i < NUM_STAGES - 1; i++) {
load_and_advance_y_pred();
}
__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
lane_id;
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
__nv_fp8x4_e4m3* y_q_base_ptr =
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
Idx_t scale_group_offset = 0;
if constexpr (std::is_same<scale_t, uint8_t>::value) {
// packed int32_t format
int pack_id = warp_position_scales / 4;
int scale_in_pack = warp_position_scales % 4;
scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
} else {
scale_group_offset = warp_position_scales * stride_ys_g;
}
int32_t compute_pipeline_offset_64 = 0;
scale_t* const y_scale_base_ptr = _y_s + scale_group_offset;
for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
__nv_bfloat162 results_bf162[2];
for (auto j = tokens_lower; j < tokens_upper; j++) {
int current_group_id = warp_position_scales; // Running count of which
// group is being processed
const Idx_t base_ys = expert_id * stride_ys_e;
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
__nv_fp8x4_e4m3* y_q_ptr =
y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t +
warp_position_yq * stride_yq_h) /
4;
const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS);
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
for (int i = 0; i < COMPUTE_LIMIT; i++) {
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
load_and_advance_y_pred();
// We double-buffer pipelined loads so that the next load will
// concurrently run with compute without overwrites.
load_and_advance_y_pred();
__int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64;
__int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64;
auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
// COMPUTE_STAGE_SIZE/MOD must also be constexpr!
compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE;
compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD;
// STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64 += STAGE_SIZE;
compute_pipeline_offset_64 %= STAGE_MOD;
__int64_t gate64 = *gate64_ptr;
__int64_t up64 = *up64_ptr;
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t gate64 = *s_gate_compute_64;
__nv_bfloat162* s_gate_compute_32 =
reinterpret_cast<__nv_bfloat162*>(&gate64);
__int64_t up64 = *s_up_compute_64;
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
// Compute
__nv_bfloat162 res[2];
__nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64);
__nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64);
#pragma unroll
for (int i = 0; i < 2; i++) {
// For silu, we make sure that div is emitted.
float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
results_bf162[i] = __float22bfloat162_rn(gate);
}
#pragma unroll
for (int i = 0; i < 2; i++) {
results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
}
for (int32_t k = 0; k < 2; ++k) {
__nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k]));
res[k] = __hmul2(gate, s_up_comp[k]);
}
auto _y_max2 =
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1]));
__nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));
_y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS);
// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
if constexpr (CEIL_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
auto inv_y = __float2bfloat16_rn(1.f) / y_s;
__nv_bfloat16 inv_y = __hdiv(one_bf16, y_s);
auto y_s2 = make_bfloat162(inv_y, inv_y);
auto y_s2 = make_bfloat162(inv_y, inv_y);
#pragma unroll
for (int32_t i = 0; i < 2; ++i) {
results_bf162[i] =
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}
for (int32_t k = 0; k < 2; ++k) {
res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}
*y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]);
y_q_ptr += WARP_SIZE * stride_yq_h;
if (!lane_id) {
// Store scales.
if constexpr (std::is_same<scale_t, uint8_t>::value) {
// Packed UE8M0 format. Remove Mantissa.
*y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7;
bool const jump_pack = (current_group_id + 1) % 4 == 0;
// Minus 3 because we need to get to the first group in the
// next pack.
y_s_ptr += jump_pack ? (stride_ys_p - 3) : stride_ys_g;
auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
y_q_ptr += stride_yq_t;
} else {
// float32 format
static_assert(std::is_same<scale_t, float>::value);
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g;
}
if (lane_id == 0) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_t;
current_group_id += 1;
}
}
}
}
#endif
}
} // namespace vllm
......@@ -475,25 +597,26 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool cast_scale_ue8m0) {
#ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static constexpr int GROUP_SIZE = 128;
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
// Check that num_parallel_tokens is a power of 2 and between 1 and 64.
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
bool const is_packed_ue8m0 =
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
using Idx_t = int64_t;
......@@ -506,85 +629,107 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_counts_e = counts.stride(0);
Idx_t stride_counts_e = tokens_per_expert.stride(0);
static constexpr int GROUP_SIZE = 128;
int const NUM_GROUPS = H / GROUP_SIZE;
#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
// TODO: Get this from cuda_arch ?
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
});
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */ \
static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \
/* 1 warp config */ \
static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
}
#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_ys_p = 0;
if (!cast_scale_ue8m0) {
TORCH_CHECK(!is_packed_ue8m0);
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
false);
return;
}
Idx_t G;
dim3 block, grid;
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
G = H / Idx_t(group_size * num_warps);
grid = dim3(E * G, _num_parallel_tokens);
block = dim3(num_warps * WARP_SIZE);
};
if (!is_packed_ue8m0) {
// UE8M0 but not packed
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
return;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
"silu_mul_fp8_quant_deep_gemm_kernel",
[&] { KERNEL_CALL_TOP_LEVEL });
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kInt32);
// Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number of experts, number of tokens and number of groups
// respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales
// tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// where TxGy is the ue8m0 scale value of Token x, Group y.
//
// In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// X, X, T3G4, T3G5, X, X]
//
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// as a uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// English, ignoring the Experts dimension, the original int32 tensor is
// simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32
// tensors). The following strides setting reflects this change. Caveat: This
// means that the G dimension is no longer contiguous. i.e. Note that to move
// from G3 to G4, we need to jump along the packing dimension. The kernel
// handles this case.
stride_ys_e *= sizeof(int32_t);
stride_ys_p = T * sizeof(int32_t); // Packing dimension
stride_ys_t = sizeof(int32_t);
stride_ys_g = 1;
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment