Unverified Commit 5f6d10c1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)

parent 9b9a10d6
...@@ -2,28 +2,22 @@ ...@@ -2,28 +2,22 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute,
int device_id)
{
int device, value; int device, value;
if (device_id < 0) { if (device_id < 0) {
cudaGetDevice(&device); cudaGetDevice(&device);
} } else {
else {
device = device_id; device = device_id;
} }
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
device);
return value; return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute( int attribute;
int device_id) // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM #ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
......
...@@ -7,11 +7,11 @@ ...@@ -7,11 +7,11 @@
// fake pointer type // fake pointer type
using fptr_t = uint64_t; using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
...@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, ...@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
} }
return (fptr_t) new vllm::CustomAllreduce( return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(), reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
} }
...@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, ...@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK * 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK * 6. A[:, 1:, 1:]: Not OK
*/ */
bool _is_weak_contiguous(torch::Tensor &t) { bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() || return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
...@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, ...@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return false; return false;
} }
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) { cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out)); TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) { switch (out.scalar_type()) {
case at::ScalarType::Float: { case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()), reinterpret_cast<float*>(out.data_ptr()),
out.numel()); out.numel());
break; break;
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: { case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>( fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break; break;
} }
#endif #endif
...@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, ...@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
} }
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
...@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { ...@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce(_fa, inp, out, stream); _all_reduce(_fa, inp, out, stream);
} }
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out) { torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
...@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, ...@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
} }
void dispose(fptr_t _fa) { void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets) { const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); return fa->get_graph_buffer_ipc_meta();
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets); fa->register_graph_buffers(handles, offsets);
} }
...@@ -31,9 +31,9 @@ struct Signal { ...@@ -31,9 +31,9 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8];
}; };
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankSignals { volatile Signal *signals[8]; }; struct __align__(16) RankSignals { volatile Signal* signals[8]; };
// like std::array, but aligned // like std::array, but aligned
template <typename T, int sz> template <typename T, int sz>
...@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { ...@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions // scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and // for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly // bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) { DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
DINLINE float &assign_add(float &a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
...@@ -80,14 +80,14 @@ template <> ...@@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val); return __float2bfloat16(val);
} }
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
#endif #endif
template <typename T, int N> template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll #pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]); assign_add(a.data[i], b.data[i]);
...@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) { ...@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against // prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes. // other volatile writes.
template <int ngpus> template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
...@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
...@@ -147,7 +146,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -147,7 +146,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier, // barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses. // we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false> template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
__syncthreads(); __syncthreads();
// eliminate the case that prior writes are not visible after signals become // eliminate the case that prior writes are not visible after signals become
...@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) { DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]); A tmp = upcast(ptrs[0][idx]);
#pragma unroll #pragma unroll
for (int i = 1; i < ngpus; i++) { for (int i = 1; i < ngpus; i++) {
...@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { ...@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg, cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
using P = typename packed_t<T>::P; using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A; using A = typename packed_t<T>::A;
...@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) ...@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
template <typename P> template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) { DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P *)(((Signal *)sg) + 1); return (P*)(((Signal*)sg) + 1);
} }
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg, cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x; int stride = gridDim.x * blockDim.x;
...@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) ...@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int start = rank * part; int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part; int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus; int largest_part = part + size % ngpus;
const P *ptrs[ngpus]; const P* ptrs[ngpus];
P *tmps[ngpus]; P* tmps[ngpus];
#pragma unroll #pragma unroll
for (int i = 0; i < ngpus; i++) { for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus; int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target]; ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]); tmps[i] = get_tmp_buf<P>(sg.signals[target]);
} }
auto tmp_out = tmps[0]; auto tmp_out = tmps[0];
...@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) ...@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus); int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) { if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx; int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx]; ((P*)result)[dst_idx] = tmps[i][idx];
} }
} }
} }
...@@ -261,14 +258,14 @@ class CustomAllreduce { ...@@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers // below are device pointers
RankSignals sg_; RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_; std::unordered_map<void*, RankData*> buffers_;
Signal *self_sg_; Signal* self_sg_;
// stores the registered device pointers from all ranks // stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_; RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_; std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers // a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
/** /**
* meta is a pointer to device metadata and temporary buffer for allreduce. * meta is a pointer to device metadata and temporary buffer for allreduce.
...@@ -279,22 +276,22 @@ class CustomAllreduce { ...@@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers * note: this class does not own any device memory. Any required buffers
* are passed in from the constructor * are passed in from the constructor
*/ */
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles, const cudaIpcMemHandle_t* handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink = true) bool full_nvlink = true)
: rank_(rank), : rank_(rank),
world_size_(offsets.size()), world_size_(offsets.size()),
full_nvlink_(full_nvlink), full_nvlink_(full_nvlink),
self_sg_(meta), self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
Signal *rank_sg; Signal* rank_sg;
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]); char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i]; handle += offsets[i];
rank_sg = (Signal *)handle; rank_sg = (Signal*)handle;
} else { } else {
rank_sg = self_sg_; rank_sg = self_sg_;
} }
...@@ -302,13 +299,13 @@ class CustomAllreduce { ...@@ -302,13 +299,13 @@ class CustomAllreduce {
} }
} }
char *open_ipc_handle(const void *ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) { if (new_handle) {
char *ipc_ptr; char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle), *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess)); cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr; it->second = ipc_ptr;
} }
...@@ -323,7 +320,7 @@ class CustomAllreduce { ...@@ -323,7 +320,7 @@ class CustomAllreduce {
std::vector<int64_t> offsets(num_buffers); std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i]; auto ptr = graph_unreg_buffers_[i];
void *base_ptr; void* base_ptr;
// note: must share the base address of each allocation, or we get wrong // note: must share the base address of each allocation, or we get wrong
// address // address
if (cuPointerGetAttribute(&base_ptr, if (cuPointerGetAttribute(&base_ptr,
...@@ -331,8 +328,8 @@ class CustomAllreduce { ...@@ -331,8 +328,8 @@ class CustomAllreduce {
(CUdeviceptr)ptr) != CUDA_SUCCESS) (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr"); throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle( CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr); offsets[i] = ((char*)ptr) - ((char*)base_ptr);
} }
return std::make_pair(handles, offsets); return std::make_pair(handles, offsets);
} }
...@@ -344,13 +341,13 @@ class CustomAllreduce { ...@@ -344,13 +341,13 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
} }
void register_buffer(const std::vector<std::string> &handles, void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, void *self) { const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity(); check_rank_data_capacity();
RankData data; RankData data;
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(handles[i].data()); char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i]; handle += offsets[i];
data.ptrs[i] = handle; data.ptrs[i] = handle;
} else { } else {
...@@ -371,17 +368,17 @@ class CustomAllreduce { ...@@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting // got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small. // mechanism so overhead should be small.
void register_graph_buffers( void register_graph_buffers(
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size(); auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers); check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers); std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i]; auto self_ptr = graph_unreg_buffers_[i];
auto &rd = rank_data[i]; auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) { for (int j = 0; j < world_size_; j++) {
if (j != rank_) { if (j != rank_) {
char *handle = char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i]; handle += offsets[j][i];
rd.ptrs[j] = handle; rd.ptrs[j] = handle;
...@@ -405,7 +402,7 @@ class CustomAllreduce { ...@@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus. * will cause contention on NVLink bus.
*/ */
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) { int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size; auto d = packed_t<T>::P::size;
if (size % d != 0) if (size % d != 0)
...@@ -418,7 +415,7 @@ class CustomAllreduce { ...@@ -418,7 +415,7 @@ class CustomAllreduce {
std::to_string(kMaxBlocks) + ". Got " + std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit)); std::to_string(block_limit));
RankData *ptrs; RankData* ptrs;
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status)); CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) { if (status == cudaStreamCaptureStatusActive) {
......
...@@ -48,7 +48,7 @@ __global__ void dummy_kernel() { ...@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
} }
template <typename T> template <typename T>
__global__ void set_data(T *data, int size, int myRank) { __global__ void set_data(T* data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f; data[idx] = myRank * 0.11f;
...@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { ...@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
} }
template <typename T> template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1, __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
double *fdata2, int size) { double* fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx]; fdata1[idx] = data1[idx];
...@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, ...@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
} }
} }
__global__ void init_rand(curandState_t *state, int size, int nRanks) { __global__ void init_rand(curandState_t* state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
...@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { ...@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
} }
template <typename T> template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
int myRank, int nRanks, int size) { int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
...@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, ...@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
} }
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test) {
T *result; T* result;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
...@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal *buffer; vllm::Signal* buffer;
T *self_data_copy; T* self_data_copy;
/** /**
* Allocate IPC buffer * Allocate IPC buffer
* *
...@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD)); MPI_BYTE, MPI_COMM_WORLD));
void *rank_data; void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024; size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0); std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank); offsets, myRank);
auto *self_data = auto* self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
std::vector<std::string> handles; std::vector<std::string> handles;
handles.reserve(nRanks); handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
char *begin = (char *)&data_handles[i]; char* begin = (char*)&data_handles[i];
char *end = (char *)&data_handles[i + 1]; char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end); handles.emplace_back(begin, end);
} }
std::vector<int64_t> offsets(nRanks, std::vector<int64_t> offsets(nRanks,
...@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa.register_buffer(handles, offsets, self_data); fa.register_buffer(handles, offsets, self_data);
} }
double *ground_truth; double* ground_truth;
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t *states; curandState_t* states;
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
...@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK(cudaStreamDestroy(stream)); CUDACHECK(cudaStreamDestroy(stream));
} }
int main(int argc, char **argv) { int main(int argc, char** argv) {
int nRanks, myRank; int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
...@@ -296,7 +296,7 @@ int main(int argc, char **argv) { ...@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId id; ncclUniqueId id;
ncclComm_t comm; ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id); if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD)); MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
......
...@@ -12,8 +12,7 @@ ...@@ -12,8 +12,7 @@
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
...@@ -22,8 +21,8 @@ ...@@ -22,8 +21,8 @@
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
...@@ -33,5 +32,4 @@ ...@@ -33,5 +32,4 @@
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
...@@ -11,26 +11,24 @@ ...@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template<typename scalar_t> template <typename scalar_t>
__global__ void rms_norm_kernel( __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx]; const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
} }
variance = blockReduceSum<float>(variance); variance = blockReduceSum<float>(variance);
...@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( ...@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
...@@ -56,46 +54,63 @@ __global__ void rms_norm_kernel( ...@@ -56,46 +54,63 @@ __global__ void rms_norm_kernel(
If false, the optimized kernel is not used for the corresponding torch type. If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template<typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion // CUDA < 12.0 runs into issues with packed type conversion
template<> template <>
struct _typeConvert<c10::Half> { struct _typeConvert<c10::Half> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __half; using hip_type = __half;
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support // CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely // TODO: Add in ROCm support once public headers handle bf16 maturely
template<> template <>
struct _typeConvert<c10::BFloat16> { struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float convert(hip_type x) {
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } return __bfloat162float(x);
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } __device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
}; };
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel. for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented. Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops. Alignment to 16 bytes is required to use 128-bit global memory ops.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
struct alignas(16) _f16Vec { struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should /* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */ almost always be the case for optimization purposes */
...@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { ...@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i+1]}; temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] += other.data[i];
data[i] += other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i+1]}; temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] *= other.data[i];
data[i] *= other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const float scale) { __device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale; temp_f.x *= scale;
temp_f.y *= scale; temp_f.y *= scale;
T2 temp = Converter::convert(temp_f); T2 temp = Converter::convert(temp_f);
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale; float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp); data[i] = Converter::convert(temp);
...@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { ...@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__ float sum_squares() const { __device__ float sum_squares() const {
float result = 0.0f; float result = 0.0f;
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]}); float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y; result += z.x * z.x + z.y * z.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]); float x = Converter::convert(data[i]);
result += x * x; result += x * x;
...@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { ...@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
...@@ -203,9 +214,12 @@ __global__ std::enable_if_t< ...@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are /* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */ in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); auto* __restrict__ input_v =
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
...@@ -218,7 +232,8 @@ __global__ std::enable_if_t< ...@@ -218,7 +232,8 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
...@@ -233,26 +248,23 @@ __global__ std::enable_if_t< ...@@ -233,26 +248,23 @@ __global__ std::enable_if_t<
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx]; scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z; float x = (float)z;
variance += x * x; variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z; residual[blockIdx.x * hidden_size + idx] = z;
} }
...@@ -260,22 +272,23 @@ __global__ std::enable_if_t< ...@@ -260,22 +272,23 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { float epsilon) {
...@@ -286,37 +299,24 @@ void rms_norm( ...@@ -286,37 +299,24 @@ void rms_norm(
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input.scalar_type(),
"rms_norm_kernel",
[&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
}); });
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \ vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \ <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \ residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \ weight.data_ptr<scalar_t>(), epsilon, \
epsilon, \ num_tokens, hidden_size); \
num_tokens, \
hidden_size); \
}); });
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
float epsilon) { float epsilon) {
...@@ -342,8 +342,8 @@ void fused_add_rms_norm( ...@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ bool ptrs_are_aligned =
&& wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
} else { } else {
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
} }
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output); torch::Tensor& gating_output);
...@@ -7,32 +7,35 @@ ...@@ -7,32 +7,35 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm { namespace vllm {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
// don't worry about overflow because num_experts is relatively small // don't worry about overflow because num_experts is relatively small
return row * total_col + col; return row * total_col + col;
} }
} } // namespace
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t *sorted_token_ids, int32_t* sorted_token_ids,
int32_t *expert_ids, int32_t* expert_ids,
int32_t *total_tokens_post_pad, int32_t* total_tokens_post_pad,
int32_t num_experts, int32_t num_experts,
int32_t block_size, int32_t block_size, size_t numel) {
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[]; extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) int32_t* tokens_cnts =
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum =
shared_mem + (num_experts + 1) *
num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
...@@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, ...@@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
/** /**
* In the first step we compute token_cnts[thread_index + 1][expert_index], * In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned * which counts how many tokens in the token shard of thread_index are
* to expert expert_index. * assigned to expert expert_index.
*/ */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
...@@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, ...@@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
// For each expert we accumulate the token counts from the different threads. // For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) { for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
} }
__syncthreads(); __syncthreads();
...@@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, ...@@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
cumsum[0] = 0; cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
} }
*total_tokens_post_pad = cumsum[num_experts]; *total_tokens_post_pad = cumsum[num_experts];
} }
...@@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, ...@@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
__syncthreads(); __syncthreads();
/** /**
* For each expert, each thread processes the tokens of the corresponding blocks * For each expert, each thread processes the tokens of the corresponding
* and stores the corresponding expert_id for each block. * blocks and stores the corresponding expert_id for each block.
*/ */
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x; expert_ids[i / block_size] = threadIdx.x;
} }
/** /**
* Each thread processes a token shard, calculating the index of each token after * Each thread processes a token shard, calculating the index of each token
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and * after sorting by expert number. Given the example topk_ids =
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* where * represents a padding value(preset in python). * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/ */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i]; int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the /** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] * expert with expert_id needs to process, and
* stores the indices of the tokens processed by the expert with expert_id within * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* the current thread's token shard. * processed by the expert with expert_id within the current thread's token
* shard.
*/ */
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i; sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
} }
} }
} } // namespace vllm
void moe_align_block_size( void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
torch::Tensor topk_ids, int block_size, torch::Tensor sorted_token_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); // tensors
const int32_t shared_mem =
((num_experts + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); (void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
num_experts,
block_size,
topk_ids.numel()); topk_ids.numel());
}); });
} }
...@@ -2,224 +2,136 @@ ...@@ -2,224 +2,136 @@
#include <torch/extension.h> #include <torch/extension.h>
void paged_attention_v1( void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& out, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& query, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens,
torch::Tensor& value_cache, int block_size, int max_seq_len,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, float kv_scale);
float kv_scale);
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
void paged_attention_v2( torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& exp_sums, torch::Tensor& value_cache, int num_kv_heads,
torch::Tensor& max_logits, float scale, torch::Tensor& block_tables,
torch::Tensor& tmp_out, torch::Tensor& seq_lens, int block_size,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int block_size,
int max_seq_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, float kv_scale);
float kv_scale);
void rms_norm( void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon); float epsilon);
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& input, torch::Tensor& weight, float epsilon);
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding( void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& key,
int head_size, void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& cos_sin_cache, torch::Tensor& key, int head_size,
bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox,
void batched_rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim, int rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul( void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_new( void gelu_new(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast( void gelu_fast(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor aqlm_gemm( torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& scales, const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes, const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias const std::optional<torch::Tensor>& bias);
);
torch::Tensor aqlm_dequant( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes const torch::Tensor& codebook_partition_sizes);
);
torch::Tensor awq_gemm( torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _in_feats, torch::Tensor _scaling_factors, torch::Tensor _zeros,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters); int split_k_iters);
torch::Tensor awq_dequantize( torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _scaling_factors,
torch::Tensor _zeros, torch::Tensor _zeros, int split_k_iters, int thx,
int split_k_iters,
int thx,
int thy); int thy);
torch::Tensor marlin_gemm( torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& a, torch::Tensor& b_scales, torch::Tensor& workspace,
torch::Tensor& b_q_weight, int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
torch::Tensor& workspace, torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_m, int64_t size_n,
int64_t size_n,
int64_t size_k); int64_t size_k);
torch::Tensor gptq_marlin_24_gemm( torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor &a, torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor &b_q_weight, torch::Tensor& perm, torch::Tensor& workspace,
torch::Tensor &b_meta, int64_t num_bits, int64_t size_m, int64_t size_n,
torch::Tensor &b_scales, int64_t size_k, bool is_k_full);
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm( torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch::Tensor &a, int64_t size_k, int64_t size_n,
torch::Tensor &b_q_weight,
torch::Tensor &b_scales,
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full);
torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n,
int64_t num_bits); int64_t num_bits);
int cutlass_scaled_mm_dq( int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const &a, torch::Tensor const& b_scales);
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);
#endif #endif
void squeezellm_gemm( void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table); torch::Tensor lookup_table);
torch::Tensor gptq_gemm( torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
torch::Tensor b_g_idx, bool use_exllama, int bit);
bool use_exllama,
int bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
void gptq_shuffle( void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor q_weight,
torch::Tensor q_perm,
int bit);
void static_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
void dynamic_scaled_fp8_quant( void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
void moe_align_block_size( void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
torch::Tensor topk_ids, int block_size, torch::Tensor sorted_token_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = uint64_t; using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink); bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink); bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out); torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int meta_size(); int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets); const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, fptr_t _fa);
const std::vector<std::vector<int64_t>> &offsets); void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif #endif
...@@ -7,14 +7,10 @@ ...@@ -7,14 +7,10 @@
namespace vllm { namespace vllm {
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index; int x_index, y_index;
scalar_t cos, sin; scalar_t cos, sin;
if (IS_NEOX) { if (IS_NEOX) {
...@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( ...@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr[y_index] = y * cos + x * sin; arr[y_index] = y * cos + x * sin;
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding( inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] // head_size] or [num_tokens, num_heads,
const scalar_t* cache_ptr, // head_size]
const int head_size, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int num_heads, // head_size] or [num_tokens, num_kv_heads,
const int num_kv_heads, // head_size]
const int rot_dim, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t query_stride, const int64_t key_stride) {
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
...@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( ...@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; const int nk = num_kv_heads * embed_dim;
...@@ -68,59 +62,71 @@ inline __device__ void apply_rotary_embedding( ...@@ -68,59 +62,71 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
} }
// Rotary-embedding kernel: rotates the query and key of one token in place
// by delegating to apply_rotary_embedding.
//
// Grid layout: each thread block is responsible for one token; blockIdx.x is
// the token index used to look up `positions`.
//
// Tensor layouts:
//   positions     : [batch_size, seq_len] or [num_tokens]
//   query         : [batch_size, seq_len, num_heads, head_size] or
//                   [num_tokens, num_heads, head_size]
//   key           : [batch_size, seq_len, num_kv_heads, head_size] or
//                   [num_tokens, num_kv_heads, head_size]
//   cos_sin_cache : [max_position, 2, rot_dim // 2]
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
    scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache,
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  const int token = blockIdx.x;

  // One cache row spans rot_dim elements, so the row for this token's
  // position begins at position * rot_dim (64-bit arithmetic, since the
  // position is an int64_t).
  const scalar_t* token_cache = cos_sin_cache + positions[token] * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, token_cache, head_size, num_heads, num_kv_heads, rot_dim,
      token, query_stride, key_stride);
}
// Batched rotary-embedding kernel: same as rotary_embedding_kernel, except
// that each token's row in the cos/sin cache is shifted by a per-token
// offset read from `cos_sin_cache_offsets`.
//
// Grid layout: each thread block is responsible for one token; blockIdx.x is
// the token index used to look up `positions` and `cos_sin_cache_offsets`.
//
// Tensor layouts:
//   positions             : [batch_size, seq_len] or [num_tokens]
//   query                 : [batch_size, seq_len, num_heads, head_size] or
//                           [num_tokens, num_heads, head_size]
//   key                   : [batch_size, seq_len, num_kv_heads, head_size] or
//                           [num_tokens, num_kv_heads, head_size]
//   cos_sin_cache         : [max_position, 2, rot_dim // 2]
//   cos_sin_cache_offsets : [batch_size, seq_len] or [num_tokens]
template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
    scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache,
    const int64_t* __restrict__ cos_sin_cache_offsets, const int rot_dim,
    const int64_t query_stride, const int64_t key_stride, const int num_heads,
    const int num_kv_heads, const int head_size) {
  const int token = blockIdx.x;

  // Cache row for this token = its per-token offset plus its position; each
  // row spans rot_dim elements.
  const int64_t cache_row = cos_sin_cache_offsets[token] + positions[token];
  const scalar_t* token_cache = cos_sin_cache + cache_row * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, token_cache, head_size, num_heads, num_kv_heads, rot_dim,
      token, query_stride, key_stride);
}
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int head_size, int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
...@@ -135,33 +141,18 @@ void rotary_embedding( ...@@ -135,33 +141,18 @@ void rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) { if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
key.data_ptr<scalar_t>(), query_stride, key_stride, num_heads, num_kv_heads, head_size);
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else { } else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>( vllm::rotary_embedding_kernel<scalar_t, false>
positions.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size); head_size);
} }
}); });
...@@ -173,12 +164,13 @@ and process in batched manner. ...@@ -173,12 +164,13 @@ and process in batched manner.
*/ */
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int head_size, int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, bool is_neox, int rot_dim,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens] torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) { ) {
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
...@@ -191,36 +183,21 @@ void batched_rotary_embedding( ...@@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) { if (is_neox) {
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( vllm::batched_rotary_embedding_kernel<scalar_t, true>
positions.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
cos_sin_cache_offsets.data_ptr<int64_t>(), key_stride, num_heads, num_kv_heads, head_size);
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else { } else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>( vllm::batched_rotary_embedding_kernel<scalar_t, false>
positions.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
cos_sin_cache_offsets.data_ptr<int64_t>(), key_stride, num_heads, num_kv_heads, head_size);
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} }
}); });
} }
...@@ -8,114 +8,85 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -8,114 +8,85 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation."); "Activation function used in GeGLU with `none` approximation.");
ops.def( ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation."); "Activation function used in GeGLU with `tanh` approximation.");
ops.def( ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
"gelu_new", ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor."); "Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization"); "In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def( ops.def("batched_rotary_embedding", &batched_rotary_embedding,
"batched_rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
&batched_rotary_embedding, "(supports multiple loras)");
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); ops.def("marlin_gemm", &marlin_gemm,
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
"gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization.");
#endif #endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); "Compute FP8 quantized tensor for given scaling factor");
ops.def( ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
"moe_align_block_size", "Compute FP8 quantized tensor and scaling factor");
&moe_align_block_size, ops.def("moe_align_block_size", &moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); "Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst"); "Swap in (out) the cache blocks from src to dst");
cache_ops.def( cache_ops.def("copy_blocks", &copy_blocks,
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
cache_ops.def( cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
"reshape_and_cache_flash",
&reshape_and_cache_flash,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
cache_ops.def( cache_ops.def("convert_fp8", &convert_fp8,
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type"); "Convert the key and value cache to fp8 data type");
// Cuda utils // Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); pybind11::module cuda_utils =
cuda_utils.def( m.def_submodule("cuda_utils", "vLLM cuda utils");
"get_device_attribute", cuda_utils.def("get_device_attribute", &get_device_attribute,
&get_device_attribute,
"Gets the specified device attribute."); "Gets the specified device attribute.");
cuda_utils.def( cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute, &get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute."); "Gets the maximum shared memory per block device attribute.");
...@@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
custom_ar.def("register_graph_buffers", &register_graph_buffers, custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers"); "register_graph_buffers");
#endif #endif
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment