Commit b9e12416 authored by zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out, void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t *__restrict__ input, const scalar_t* __restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon, const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) { const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
...@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, ...@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
} }
template <typename scalar_t> template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input, void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t *__restrict__ residual, scalar_t* __restrict__ residual,
const scalar_t *__restrict__ weight, const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens, const float epsilon, const int num_tokens,
const int hidden_size) { const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
...@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, ...@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} }
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl) CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size); hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl) CPU_KERNEL_GUARD_OUT(rms_norm_impl)
}); });
} }
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor &weight, float epsilon) { torch::Tensor& weight, float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
...@@ -4,22 +4,21 @@ ...@@ -4,22 +4,21 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
...@@ -27,7 +26,7 @@ void rotary_embedding_impl( ...@@ -27,7 +26,7 @@ void rotary_embedding_impl(
#pragma omp parallel for #pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
const int head_idx = i; const int head_idx = i;
...@@ -95,16 +94,16 @@ void rotary_embedding_impl( ...@@ -95,16 +94,16 @@ void rotary_embedding_impl(
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
...@@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl( ...@@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query; scalar_t* head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
...@@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl( ...@@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head; scalar_t* head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
...@@ -165,11 +164,11 @@ void rotary_embedding_gptj_impl( ...@@ -165,11 +164,11 @@ void rotary_embedding_gptj_impl(
} }
} }
} }
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor &key, int head_size, torch::Tensor& key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
......
...@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul, "Activation function used in GeGLU with `none` approximation.");
"Activation function used in SwiGLU."); ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
ops.def( "Activation function used in GeGLU with `tanh` approximation.");
"gelu_and_mul", ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
&gelu_and_mul, ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm", "Apply Root Mean Square (RMS) Normalization to the input tensor.");
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm", "In-place fused Add and RMS Normalization");
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks", "Swap in (out) the cache blocks from src to dst");
&swap_blocks, cache_ops.def("copy_blocks", &copy_blocks,
"Swap in (out) the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"copy_blocks", "Reshape the key and value tensors and cache them");
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
} }
#pragma once #pragma once
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif #endif
...@@ -28,6 +29,13 @@ ...@@ -28,6 +29,13 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif #endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
...@@ -35,4 +43,3 @@ ...@@ -35,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
int get_device_attribute( int get_device_attribute(int attribute, int device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute( int get_max_shared_memory_per_block_device_attribute(int device_id);
int device_id);
...@@ -2,34 +2,28 @@ ...@@ -2,34 +2,28 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute, int device, value;
int device_id) if (device_id < 0) {
{ cudaGetDevice(&device);
int device, value; } else {
if (device_id < 0) { device = device_id;
cudaGetDevice(&device); }
} cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
else { device);
device = device_id; return value;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute( int attribute;
int device_id) // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM #ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else #else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif #endif
return get_device_attribute(attribute, device_id); return get_device_attribute(attribute, device_id);
} }
...@@ -7,11 +7,11 @@ ...@@ -7,11 +7,11 @@
// fake pointer type // fake pointer type
using fptr_t = uint64_t; using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
...@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, ...@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
} }
return (fptr_t) new vllm::CustomAllreduce( return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(), reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
} }
...@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, ...@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK * 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK * 6. A[:, 1:, 1:]: Not OK
*/ */
bool _is_weak_contiguous(torch::Tensor &t) { bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() || return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
...@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, ...@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return false; return false;
} }
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) { cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out)); TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) { switch (out.scalar_type()) {
case at::ScalarType::Float: { case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()), reinterpret_cast<float*>(out.data_ptr()),
out.numel()); out.numel());
break; break;
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: { case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>( fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break; break;
} }
#endif #endif
...@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, ...@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
} }
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
...@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { ...@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce(_fa, inp, out, stream); _all_reduce(_fa, inp, out, stream);
} }
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out) { torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
...@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, ...@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
} }
void dispose(fptr_t _fa) { void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets) { const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); return fa->get_graph_buffer_ipc_meta();
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets); fa->register_graph_buffers(handles, offsets);
} }
...@@ -31,9 +31,9 @@ struct Signal { ...@@ -31,9 +31,9 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8];
}; };
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankSignals { volatile Signal *signals[8]; }; struct __align__(16) RankSignals { volatile Signal* signals[8]; };
// like std::array, but aligned // like std::array, but aligned
template <typename T, int sz> template <typename T, int sz>
...@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { ...@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions // scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and // for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly // bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) { DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
DINLINE float &assign_add(float &a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
...@@ -80,14 +80,14 @@ template <> ...@@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val); return __float2bfloat16(val);
} }
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
#endif #endif
template <typename T, int N> template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll #pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]); assign_add(a.data[i], b.data[i]);
...@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) { ...@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against // prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes. // other volatile writes.
template <int ngpus> template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
...@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
...@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier, // barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses. // we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false> template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
__syncthreads(); __syncthreads();
// eliminate the case that prior writes are not visible after signals become // eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of // visible. Note that I did not manage to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than // testing. Might be the case that hardware provides stronger guarantee than
// the memory model. // the memory model.
if constexpr (!final_sync) __threadfence_system(); if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
...@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) { DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]); A tmp = upcast(ptrs[0][idx]);
#pragma unroll #pragma unroll
for (int i = 1; i < ngpus; i++) { for (int i = 1; i < ngpus; i++) {
...@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { ...@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg, cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
using P = typename packed_t<T>::P; using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A; using A = typename packed_t<T>::A;
...@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) ...@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
template <typename P> template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) { DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P *)(((Signal *)sg) + 1); return (P*)(((Signal*)sg) + 1);
} }
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg, cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x; int stride = gridDim.x * blockDim.x;
...@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) ...@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int start = rank * part; int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part; int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus; int largest_part = part + size % ngpus;
const P *ptrs[ngpus]; const P* ptrs[ngpus];
P *tmps[ngpus]; P* tmps[ngpus];
#pragma unroll #pragma unroll
for (int i = 0; i < ngpus; i++) { for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus; int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target]; ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]); tmps[i] = get_tmp_buf<P>(sg.signals[target]);
} }
auto tmp_out = tmps[0]; auto tmp_out = tmps[0];
...@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) ...@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus); int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) { if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx; int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx]; ((P*)result)[dst_idx] = tmps[i][idx];
} }
} }
} }
...@@ -261,14 +258,14 @@ class CustomAllreduce { ...@@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers // below are device pointers
RankSignals sg_; RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_; std::unordered_map<void*, RankData*> buffers_;
Signal *self_sg_; Signal* self_sg_;
// stores the registered device pointers from all ranks // stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_; RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_; std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers // a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
/** /**
* meta is a pointer to device metadata and temporary buffer for allreduce. * meta is a pointer to device metadata and temporary buffer for allreduce.
...@@ -279,22 +276,22 @@ class CustomAllreduce { ...@@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers * note: this class does not own any device memory. Any required buffers
* are passed in from the constructor * are passed in from the constructor
*/ */
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles, const cudaIpcMemHandle_t* handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink = true) bool full_nvlink = true)
: rank_(rank), : rank_(rank),
world_size_(offsets.size()), world_size_(offsets.size()),
full_nvlink_(full_nvlink), full_nvlink_(full_nvlink),
self_sg_(meta), self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
Signal *rank_sg; Signal* rank_sg;
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]); char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i]; handle += offsets[i];
rank_sg = (Signal *)handle; rank_sg = (Signal*)handle;
} else { } else {
rank_sg = self_sg_; rank_sg = self_sg_;
} }
...@@ -302,13 +299,13 @@ class CustomAllreduce { ...@@ -302,13 +299,13 @@ class CustomAllreduce {
} }
} }
char *open_ipc_handle(const void *ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) { if (new_handle) {
char *ipc_ptr; char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle), *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess)); cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr; it->second = ipc_ptr;
} }
...@@ -323,7 +320,7 @@ class CustomAllreduce { ...@@ -323,7 +320,7 @@ class CustomAllreduce {
std::vector<int64_t> offsets(num_buffers); std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i]; auto ptr = graph_unreg_buffers_[i];
void *base_ptr; void* base_ptr;
// note: must share the base address of each allocation, or we get wrong // note: must share the base address of each allocation, or we get wrong
// address // address
if (cuPointerGetAttribute(&base_ptr, if (cuPointerGetAttribute(&base_ptr,
...@@ -331,8 +328,8 @@ class CustomAllreduce { ...@@ -331,8 +328,8 @@ class CustomAllreduce {
(CUdeviceptr)ptr) != CUDA_SUCCESS) (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr"); throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle( CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr); offsets[i] = ((char*)ptr) - ((char*)base_ptr);
} }
return std::make_pair(handles, offsets); return std::make_pair(handles, offsets);
} }
...@@ -344,13 +341,13 @@ class CustomAllreduce { ...@@ -344,13 +341,13 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
} }
void register_buffer(const std::vector<std::string> &handles, void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, void *self) { const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity(); check_rank_data_capacity();
RankData data; RankData data;
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(handles[i].data()); char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i]; handle += offsets[i];
data.ptrs[i] = handle; data.ptrs[i] = handle;
} else { } else {
...@@ -371,17 +368,17 @@ class CustomAllreduce { ...@@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting // got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small. // mechanism so overhead should be small.
void register_graph_buffers( void register_graph_buffers(
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size(); auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers); check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers); std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i]; auto self_ptr = graph_unreg_buffers_[i];
auto &rd = rank_data[i]; auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) { for (int j = 0; j < world_size_; j++) {
if (j != rank_) { if (j != rank_) {
char *handle = char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i]; handle += offsets[j][i];
rd.ptrs[j] = handle; rd.ptrs[j] = handle;
...@@ -405,7 +402,7 @@ class CustomAllreduce { ...@@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus. * will cause contention on NVLink bus.
*/ */
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) { int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size; auto d = packed_t<T>::P::size;
if (size % d != 0) if (size % d != 0)
...@@ -418,7 +415,7 @@ class CustomAllreduce { ...@@ -418,7 +415,7 @@ class CustomAllreduce {
std::to_string(kMaxBlocks) + ". Got " + std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit)); std::to_string(block_limit));
RankData *ptrs; RankData* ptrs;
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status)); CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) { if (status == cudaStreamCaptureStatusActive) {
......
...@@ -48,7 +48,7 @@ __global__ void dummy_kernel() { ...@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
} }
template <typename T> template <typename T>
__global__ void set_data(T *data, int size, int myRank) { __global__ void set_data(T* data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f; data[idx] = myRank * 0.11f;
...@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { ...@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
} }
template <typename T> template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1, __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
double *fdata2, int size) { double* fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx]; fdata1[idx] = data1[idx];
...@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, ...@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
} }
} }
__global__ void init_rand(curandState_t *state, int size, int nRanks) { __global__ void init_rand(curandState_t* state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
...@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { ...@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
} }
template <typename T> template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
int myRank, int nRanks, int size) { int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
...@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, ...@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
} }
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test) {
T *result; T* result;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
...@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal *buffer; vllm::Signal* buffer;
T *self_data_copy; T* self_data_copy;
/** /**
* Allocate IPC buffer * Allocate IPC buffer
* *
...@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD)); MPI_BYTE, MPI_COMM_WORLD));
void *rank_data; void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024; size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0); std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank); offsets, myRank);
auto *self_data = auto* self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
std::vector<std::string> handles; std::vector<std::string> handles;
handles.reserve(nRanks); handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
char *begin = (char *)&data_handles[i]; char* begin = (char*)&data_handles[i];
char *end = (char *)&data_handles[i + 1]; char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end); handles.emplace_back(begin, end);
} }
std::vector<int64_t> offsets(nRanks, std::vector<int64_t> offsets(nRanks,
...@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa.register_buffer(handles, offsets, self_data); fa.register_buffer(handles, offsets, self_data);
} }
double *ground_truth; double* ground_truth;
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t *states; curandState_t* states;
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
...@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK(cudaStreamDestroy(stream)); CUDACHECK(cudaStreamDestroy(stream));
} }
int main(int argc, char **argv) { int main(int argc, char** argv) {
int nRanks, myRank; int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
...@@ -296,7 +296,7 @@ int main(int argc, char **argv) { ...@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId id; ncclUniqueId id;
ncclComm_t comm; ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id); if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD)); MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
......
...@@ -6,32 +6,30 @@ ...@@ -6,32 +6,30 @@
#include <torch/extension.h> #include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
...@@ -11,26 +11,24 @@ ...@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template<typename scalar_t> template <typename scalar_t>
__global__ void rms_norm_kernel( __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx]; const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
} }
variance = blockReduceSum<float>(variance); variance = blockReduceSum<float>(variance);
...@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( ...@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
...@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( ...@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel(
Each struct should have the member static constexpr bool `exists`: Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type. If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template<typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion // CUDA < 12.0 runs into issues with packed type conversion
template<> template <>
struct _typeConvert<c10::Half> { struct _typeConvert<c10::Half> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __half; using hip_type = __half;
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support // CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely // TODO: Add in ROCm support once public headers handle bf16 maturely
template<> template <>
struct _typeConvert<c10::BFloat16> { struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float convert(hip_type x) {
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } return __bfloat162float(x);
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } __device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
}; };
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel. for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented. Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops. Alignment to 16 bytes is required to use 128-bit global memory ops.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
struct alignas(16) _f16Vec { struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should /* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */ almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0, static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!"); "Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>; using Converter = _typeConvert<scalar_t>;
...@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { ...@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i+1]}; temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] += other.data[i];
data[i] += other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i+1]}; temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] *= other.data[i];
data[i] *= other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const float scale) { __device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale; temp_f.x *= scale;
temp_f.y *= scale; temp_f.y *= scale;
T2 temp = Converter::convert(temp_f); T2 temp = Converter::convert(temp_f);
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale; float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp); data[i] = Converter::convert(temp);
...@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { ...@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__ float sum_squares() const { __device__ float sum_squares() const {
float result = 0.0f; float result = 0.0f;
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]}); float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y; result += z.x * z.x + z.y * z.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]); float x = Converter::convert(data[i]);
result += x * x; result += x * x;
...@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { ...@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
...@@ -203,9 +214,12 @@ __global__ std::enable_if_t< ...@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are /* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */ in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); auto* __restrict__ input_v =
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
...@@ -215,10 +229,11 @@ __global__ std::enable_if_t< ...@@ -215,10 +229,11 @@ __global__ std::enable_if_t<
residual_v[id] = temp; residual_v[id] = temp;
} }
/* Keep the following if-else block in sync with the /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
...@@ -233,52 +248,50 @@ __global__ std::enable_if_t< ...@@ -233,52 +248,50 @@ __global__ std::enable_if_t<
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx]; scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z; float x = (float)z;
variance += x * x; variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z; residual[blockIdx.x * hidden_size + idx] = z;
} }
/* Keep the following if-else block in sync with the /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& weight, // [hidden_size] float epsilon) {
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -286,40 +299,27 @@ void rms_norm( ...@@ -286,40 +299,27 @@ void rms_norm(
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input.scalar_type(), vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
"rms_norm_kernel", out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
[&] { weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( });
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \ vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \ <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \ residual.data_ptr<scalar_t>(), \
<scalar_t, width><<<grid, block, 0, stream>>>( \ weight.data_ptr<scalar_t>(), epsilon, \
input.data_ptr<scalar_t>(), \ num_tokens, hidden_size); \
residual.data_ptr<scalar_t>(), \ });
weight.data_ptr<scalar_t>(), \
epsilon, \ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
num_tokens, \ torch::Tensor& residual, // [..., hidden_size]
hidden_size); \ torch::Tensor& weight, // [hidden_size]
}); float epsilon) {
void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -342,8 +342,8 @@ void fused_add_rms_norm( ...@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ bool ptrs_are_aligned =
&& wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
} else { } else {
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
} }
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights, torch::Tensor& token_expert_indices,
torch::Tensor& topk_indices, torch::Tensor& gating_output);
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
...@@ -7,119 +7,128 @@ ...@@ -7,119 +7,128 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm { namespace vllm {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
// don't worry about overflow because num_experts is relatively small int32_t col) {
return row * total_col + col; // don't worry about overflow because num_experts is relatively small
} return row * total_col + col;
} }
} // namespace
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t *sorted_token_ids, int32_t* sorted_token_ids,
int32_t *expert_ids, int32_t* expert_ids,
int32_t *total_tokens_post_pad, int32_t* total_tokens_post_pad,
int32_t num_experts, int32_t num_experts,
int32_t block_size, int32_t block_size, size_t numel) {
size_t numel) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread;
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[];
extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts =
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) int32_t* cumsum =
shared_mem + (num_experts + 1) *
for (int i = 0; i < num_experts; ++i) { num_experts; // 1d tensor with shape (num_experts + 1)
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
} for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
/** }
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned /**
* to expert expert_index. * In the first step we compute token_cnts[thread_index + 1][expert_index],
*/ * which counts how many tokens in the token shard of thread_index are
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { * assigned to expert expert_index.
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; */
} for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
__syncthreads(); }
// For each expert we accumulate the token counts from the different threads. __syncthreads();
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) { // For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
} for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
__syncthreads(); tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) { __syncthreads();
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { // We accumulate the token counts of all experts in thread 0.
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; if (threadIdx.x == 0) {
} cumsum[0] = 0;
*total_tokens_post_pad = cumsum[num_experts]; for (int i = 1; i <= num_experts; ++i) {
} cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
__syncthreads(); block_size) *
block_size;
/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
} }
*total_tokens_post_pad = cumsum[num_experts];
/** }
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and __syncthreads();
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python). /**
*/ * For each expert, each thread processes the tokens of the corresponding
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { * blocks and stores the corresponding expert_id for each block.
int32_t expert_id = topk_ids[i]; */
/** The cumsum[expert_id] stores the starting index of the tokens that the for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] i += block_size) {
* stores the indices of the tokens processed by the expert with expert_id within expert_ids[i / block_size] = threadIdx.x;
* the current thread's token shard. }
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; /**
sorted_token_ids[rank_post_pad] = i; * Each thread processes a token shard, calculating the index of each token
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; * after sorting by expert number. Given the example topk_ids =
} * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
} * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
} }
} // namespace vllm
void moe_align_block_size(
torch::Tensor topk_ids, void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int num_experts, int block_size, torch::Tensor sorted_token_ids,
int block_size, torch::Tensor experts_ids,
torch::Tensor sorted_token_ids, torch::Tensor num_tokens_post_pad) {
torch::Tensor experts_ids, const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor num_tokens_post_pad) { VLLM_DISPATCH_INTEGRAL_TYPES(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
VLLM_DISPATCH_INTEGRAL_TYPES( // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // tensors
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors const int32_t shared_mem =
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); ((num_experts + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); (void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel()); topk_ids.numel());
}); });
} }
...@@ -3,204 +3,139 @@ ...@@ -3,204 +3,139 @@
#include <torch/extension.h> #include <torch/extension.h>
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor& value_cache, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int num_kv_heads, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
float scale, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
torch::Tensor& block_tables, const int blocksparse_block_size, const int blocksparse_head_sliding_step);
torch::Tensor& seq_lens,
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor& query, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
torch::Tensor& value_cache, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
int num_kv_heads, const int blocksparse_block_size, const int blocksparse_head_sliding_step);
float scale,
torch::Tensor& block_tables, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& seq_lens, float epsilon);
int block_size,
int max_seq_len, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& weight, float epsilon);
const std::string& kv_cache_dtype,
float kv_scale); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
void rms_norm( torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& out,
torch::Tensor& input, void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& weight, torch::Tensor& key, int head_size,
float epsilon); torch::Tensor& cos_sin_cache, bool is_neox,
int rot_dim,
void fused_add_rms_norm( torch::Tensor& cos_sin_cache_offsets);
torch::Tensor& input,
torch::Tensor& residual, void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& weight,
float epsilon); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
void rotary_embedding( void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& positions,
torch::Tensor& query, void gelu_new(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& key,
int head_size, void gelu_fast(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor aqlm_gemm( torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& input, const torch::Tensor& codebooks,
const torch::Tensor& codes, const torch::Tensor& scales,
const torch::Tensor& codebooks, const torch::Tensor& codebook_partition_sizes,
const torch::Tensor& scales, const std::optional<torch::Tensor>& bias);
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias torch::Tensor aqlm_dequant(const torch::Tensor& codes,
); const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes);
torch::Tensor aqlm_dequant(
const torch::Tensor& codes, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
const torch::Tensor& codebooks, torch::Tensor _scaling_factors, torch::Tensor _zeros,
const torch::Tensor& codebook_partition_sizes int split_k_iters);
);
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor awq_gemm( torch::Tensor _scaling_factors,
torch::Tensor _in_feats, torch::Tensor _zeros, int split_k_iters, int thx,
torch::Tensor _kernel, int thy);
torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int split_k_iters); torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor awq_dequantize(
torch::Tensor _kernel, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor _scaling_factors, torch::Tensor& b_meta,
torch::Tensor _zeros, torch::Tensor& b_scales,
int split_k_iters, torch::Tensor& workspace, int64_t num_bits,
int thx, int64_t size_m, int64_t size_n,
int thy); int64_t size_k);
torch::Tensor marlin_gemm( torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& a, torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor& workspace,
torch::Tensor& b_scales, int64_t num_bits, int64_t size_m, int64_t size_n,
torch::Tensor& workspace, int64_t size_k, bool is_k_full);
int64_t size_m,
int64_t size_n, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k); int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor gptq_marlin_gemm(
torch::Tensor &a, int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor &b_q_weight, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor &b_scales, torch::Tensor const& b_scales);
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full);
torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n,
int64_t num_bits);
#endif #endif
void squeezellm_gemm( // void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor vec, // float scale);
torch::Tensor mat,
torch::Tensor mul, void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table); torch::Tensor lookup_table);
torch::Tensor gptq_gemm( torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor a, torch::Tensor b_gptq_qzeros,
torch::Tensor b_q_weight, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
torch::Tensor b_gptq_qzeros, bool use_exllama, int bit);
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
bool use_exllama,
int bit); // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
void gptq_shuffle(
torch::Tensor q_weight, // void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor q_perm, // torch::Tensor& scale);
int bit);
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
// void static_scaled_fp8_quant( int block_size, torch::Tensor sorted_token_ids,
// torch::Tensor& out, torch::Tensor experts_ids,
// torch::Tensor& input, torch::Tensor num_tokens_post_pad);
// torch::Tensor& scale);
// void dynamic_scaled_fp8_quant(
// torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& scale);
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = uint64_t; using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink); bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink); bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out); torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int meta_size(); int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets); const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, fptr_t _fa);
const std::vector<std::vector<int64_t>> &offsets); void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif #endif
...@@ -7,14 +7,10 @@ ...@@ -7,14 +7,10 @@
namespace vllm { namespace vllm {
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index; int x_index, y_index;
scalar_t cos, sin; scalar_t cos, sin;
if (IS_NEOX) { if (IS_NEOX) {
...@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( ...@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr[y_index] = y * cos + x * sin; arr[y_index] = y * cos + x * sin;
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding( inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] // head_size] or [num_tokens, num_heads,
const scalar_t* cache_ptr, // head_size]
const int head_size, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int num_heads, // head_size] or [num_tokens, num_kv_heads,
const int num_kv_heads, // head_size]
const int rot_dim, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t query_stride, const int64_t key_stride) {
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
...@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( ...@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; const int nk = num_kv_heads * embed_dim;
...@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( ...@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel( __global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int rot_dim, // head_size]
const int64_t query_stride, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t key_stride, // head_size] or [num_tokens, num_kv_heads,
const int num_heads, // head_size]
const int num_kv_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int head_size) { // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel( __global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] // head_size]
const int rot_dim, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t query_stride, // head_size] or [num_tokens, num_kv_heads,
const int64_t key_stride, // head_size]
const int num_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int num_kv_heads, // 2]
const int head_size) { const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; const scalar_t* cache_ptr =
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] // [num_tokens, num_kv_heads * head_size]
bool is_neox) { int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
...@@ -135,36 +141,21 @@ void rotary_embedding( ...@@ -135,36 +141,21 @@ void rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(), if (is_neox) {
"rotary_embedding", vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
[&] { positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
if (is_neox) { key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( query_stride, key_stride, num_heads, num_kv_heads, head_size);
positions.data_ptr<int64_t>(), } else {
query.data_ptr<scalar_t>(), vllm::rotary_embedding_kernel<scalar_t, false>
key.data_ptr<scalar_t>(), <<<grid, block, 0, stream>>>(
cos_sin_cache.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
rot_dim, key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
query_stride, rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
key_stride, head_size);
num_heads, }
num_kv_heads, });
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
} }
/* /*
...@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together ...@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner. and process in batched manner.
*/ */
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] // [num_tokens, num_kv_heads * head_size]
bool is_neox, int head_size,
int rot_dim, torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
torch::Tensor& cos_sin_cache_offsets // [num_tokens] bool is_neox, int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) { ) {
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
...@@ -191,36 +183,21 @@ void batched_rotary_embedding( ...@@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(), if (is_neox) {
"rotary_embedding", vllm::batched_rotary_embedding_kernel<scalar_t, true>
[&] { <<<grid, block, 0, stream>>>(
if (is_neox) { positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
query.data_ptr<scalar_t>(), key_stride, num_heads, num_kv_heads, head_size);
key.data_ptr<scalar_t>(), } else {
cos_sin_cache.data_ptr<scalar_t>(), vllm::batched_rotary_embedding_kernel<scalar_t, false>
cos_sin_cache_offsets.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
rot_dim, positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
query_stride, key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
key_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
num_heads, key_stride, num_heads, num_kv_heads, head_size);
num_kv_heads, }
head_size); });
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
} }
...@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \ f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4096) \
...@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7168) \
...@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32256) \
...@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \
...@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \
...@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 22016, narrow) \ f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \
......
#pragma once #pragma once
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h> #include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline> #include <cuda/pipeline>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
#include <stdio.h> #include <stdio.h>
...@@ -11,6 +17,24 @@ ...@@ -11,6 +17,24 @@
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
#ifdef USE_ROCM
// Synchronous byte-wise copy of exactly `len` bytes from `src` to `dst`.
// The byte count is a template argument, so the loop trip count is known at
// compile time and the loop can be fully unrolled. Callable from host and
// device code. Returns `dst`, mirroring the convention of std::memcpy.
// Caveat carried over from the original: does not handle the case of long
// datatypes.
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
  auto *out_bytes = static_cast<char *>(dst);
  const auto *in_bytes = static_cast<const char *>(src);
#pragma unroll
  for (size_t byte_idx = 0; byte_idx != len; ++byte_idx) {
    out_bytes[byte_idx] = in_bytes[byte_idx];
  }
  return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4) // nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size, template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T, size_t W_copy_size, int tx, int ty, int tz, typename in_T,
...@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
} }
} }
#else
// bgmv_shrink_kernel (variant compiled when USE_ROCM is defined).
//
// Each thread block handles one (batch row, output feature) pair:
//   batch_idx = blockIdx.y, j = blockIdx.x.
// It accumulates the dot product, over feat_in elements, of
//   X[batch_idx * feat_in + ...]  and  W[(idx * feat_out + j) * feat_in + ...]
// (each product multiplied by `scale`) and adds the result into
//   Y[batch_idx * full_y_size + y_offset + j].
//
// Template parameters:
//   feat_in / feat_out  - compile-time extents used for indexing X and W.
//   vec_size            - number of elements each thread loads per tile.
//   tx, ty              - thread-block extents used in the index arithmetic
//                         (threadIdx.x in [0, tx), threadIdx.y in [0, ty)).
//   X_copy_size, W_copy_size, tz - not referenced anywhere in this body.
//   in_T / out_T / W_T  - element types of X, Y and W.
//
// Runtime parameters:
//   indicies    - per-batch-row index; combined with num_layers and layer_idx
//                 to select the W slice. A negative combined index makes the
//                 whole block return without touching Y.
//                 (presumably "no adapter for this row" - confirm with caller)
//   y_offset / full_y_size - column offset and row stride used to address Y.
//   scale       - factor applied to every product before accumulation.
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
// Depends only on blockIdx, so either every thread of the block returns here
// or none does; the barriers below are therefore reached uniformly.
if (idx < 0) {
return;
}
size_t j = blockIdx.x;
// One tile = the span of feat_in covered by all tx * ty threads in one pass.
constexpr size_t tile_size = tx * ty * vec_size;
// Ceiling division: the last tile may be only partially inside feat_in.
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
// One partial result per threadIdx.y row, written by the threadIdx.x == 0 lane.
__shared__ float y_warpwise[ty];
float y = 0;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
// Load only when this thread's whole vec_size-wide chunk lies inside feat_in
// (the check tests the index of the chunk's last element).
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
x_vec.load(X + (batch_idx * feat_in) +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
}
// Computed unconditionally: a thread that skipped the load above multiplies
// whatever x_vec / w_vec currently hold (stale data from an earlier tile, or
// never-loaded contents on the first tile). Its own `sum` is discarded by the
// bounds check before `y += sum` below.
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
}
// Halving-offset reduction across the tx lanes of a row.
// NOTE(review): VLLM_SHFL_DOWN_SYNC is defined elsewhere; presumably a
// shuffle-down that returns the value held `offset` lanes higher. Lanes that
// skipped the load still take part, so confirm the launch configurations
// guarantee the threadIdx.x == 0 lane never folds in their partial sums.
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
}
__syncthreads();
// Same bounds check as for the load: only threads whose chunk was in range
// add this tile's reduced value to their running total.
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
y += sum;
}
}
// Publish one running total per threadIdx.y row.
if (threadIdx.x == 0) {
y_warpwise[threadIdx.y] = y;
}
// Make every row's partial visible before it is summed below.
__syncthreads();
// Every thread sums the ty partials, but only thread (0, 0) uses the result.
float y_write = 0.f;
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y_write += y_warpwise[i];
}
// write Y;
// Accumulates into the existing Y value (read, add via vllm_add, store) rather
// than overwriting it.
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t y_idx = batch_idx * full_y_size + y_offset + j;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
}
}
#endif
// nthrs = (2, 16, 4) // nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz, template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T> typename in_T, typename out_T, typename W_T>
...@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
float sum = 0.f; float sum = 0.f;
#pragma unroll #pragma unroll
for (size_t i = 0; i < vec_size; ++i) { for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
sum += float(w_vec[i]) * float(x_vec[i]) * scale; sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
} }
cg::thread_block_tile g = cg::tiled_partition<tx>(block); cg::thread_block_tile g = cg::tiled_partition<tx>(block);
...@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
#ifndef USE_ROCM
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum); threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
} }
} }
...@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
scale); scale);
} }
} else { } else {
#ifndef USE_ROCM
static_assert(feat_in % (vec_size * 32) == 0 || static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0); feat_in % (vec_size * 8) == 0);
...@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
full_y_size, num_layers, layer_idx, full_y_size, num_layers, layer_idx,
scale); scale);
} }
#else
constexpr size_t rocm_warp_size = warpSize;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
CHECK_INPUT_TILEABLE_BY(16) ||
CHECK_INPUT_TILEABLE_BY( 8) ||
CHECK_INPUT_TILEABLE_BY( 4) ||
CHECK_INPUT_TILEABLE_BY( 2) ||
CHECK_INPUT_TILEABLE_BY( 1));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
} }
} }
......
#ifndef VEC_DTYPES_CUH_ #ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8 #ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h> #include <cuda_fp8.h>
#endif #endif
...@@ -10,6 +8,9 @@ ...@@ -10,6 +8,9 @@
#include <type_traits> #include <type_traits>
#include "../type_convert.h"
#include "../../cuda_compat.h"
#define FLASHINFER_INLINE \ #define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__ inline __attribute__((always_inline)) __device__ __host__
......
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <cstdint> #include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h" #include "bgmv/bgmv_config.h"
namespace {
//====== utils ====== //====== utils ======
...@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, ...@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
} }
} // namespace
//====== pybind ======
// Binds the function `name` on module `m` under its own identifier and reuses
// that identifier as the docstring. The expansion already ends in ';', so the
// macro is invoked without a trailing semicolon.
#define DEFINE_pybind(name) m.def(#name, &name, #name);

// Python extension entry point: exposes both BGMV dispatch functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  DEFINE_pybind(dispatch_bgmv)
  DEFINE_pybind(dispatch_bgmv_low_level)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment