Unverified Commit 5f6d10c1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)

parent 9b9a10d6
...@@ -2,34 +2,28 @@ ...@@ -2,34 +2,28 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute, int device, value;
int device_id) if (device_id < 0) {
{ cudaGetDevice(&device);
int device, value; } else {
if (device_id < 0) { device = device_id;
cudaGetDevice(&device); }
} cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
else { device);
device = device_id; return value;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute( int attribute;
int device_id) // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM #ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else #else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif #endif
return get_device_attribute(attribute, device_id); return get_device_attribute(attribute, device_id);
} }
...@@ -7,11 +7,11 @@ ...@@ -7,11 +7,11 @@
// fake pointer type // fake pointer type
using fptr_t = uint64_t; using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
...@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, ...@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
} }
return (fptr_t) new vllm::CustomAllreduce( return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(), reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
} }
...@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, ...@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK * 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK * 6. A[:, 1:, 1:]: Not OK
*/ */
bool _is_weak_contiguous(torch::Tensor &t) { bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() || return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
...@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, ...@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return false; return false;
} }
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) { cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out)); TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) { switch (out.scalar_type()) {
case at::ScalarType::Float: { case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()), reinterpret_cast<float*>(out.data_ptr()),
out.numel()); out.numel());
break; break;
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: { case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>( fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break; break;
} }
#endif #endif
...@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, ...@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
} }
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
...@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { ...@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce(_fa, inp, out, stream); _all_reduce(_fa, inp, out, stream);
} }
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out) { torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
...@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, ...@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
} }
void dispose(fptr_t _fa) { void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets) { const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); return fa->get_graph_buffer_ipc_meta();
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets); fa->register_graph_buffers(handles, offsets);
} }
...@@ -31,9 +31,9 @@ struct Signal { ...@@ -31,9 +31,9 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8];
}; };
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankSignals { volatile Signal *signals[8]; }; struct __align__(16) RankSignals { volatile Signal* signals[8]; };
// like std::array, but aligned // like std::array, but aligned
template <typename T, int sz> template <typename T, int sz>
...@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { ...@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions // scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and // for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly // bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) { DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
DINLINE float &assign_add(float &a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
...@@ -80,14 +80,14 @@ template <> ...@@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val); return __float2bfloat16(val);
} }
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
#endif #endif
template <typename T, int N> template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll #pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]); assign_add(a.data[i], b.data[i]);
...@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) { ...@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against // prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes. // other volatile writes.
template <int ngpus> template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
...@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
...@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier, // barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses. // we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false> template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
__syncthreads(); __syncthreads();
// eliminate the case that prior writes are not visible after signals become // eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of // visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than // testing. Might be the case that hardware provides stronger guarantee than
// the memory model. // the memory model.
if constexpr (!final_sync) __threadfence_system(); if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
...@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, ...@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) { DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]); A tmp = upcast(ptrs[0][idx]);
#pragma unroll #pragma unroll
for (int i = 1; i < ngpus; i++) { for (int i = 1; i < ngpus; i++) {
...@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { ...@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg, cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
using P = typename packed_t<T>::P; using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A; using A = typename packed_t<T>::A;
...@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) ...@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
template <typename P> template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) { DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P *)(((Signal *)sg) + 1); return (P*)(((Signal*)sg) + 1);
} }
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg, cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x; int stride = gridDim.x * blockDim.x;
...@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) ...@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int start = rank * part; int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part; int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus; int largest_part = part + size % ngpus;
const P *ptrs[ngpus]; const P* ptrs[ngpus];
P *tmps[ngpus]; P* tmps[ngpus];
#pragma unroll #pragma unroll
for (int i = 0; i < ngpus; i++) { for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus; int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target]; ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]); tmps[i] = get_tmp_buf<P>(sg.signals[target]);
} }
auto tmp_out = tmps[0]; auto tmp_out = tmps[0];
...@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) ...@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus); int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) { if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx; int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx]; ((P*)result)[dst_idx] = tmps[i][idx];
} }
} }
} }
...@@ -261,14 +258,14 @@ class CustomAllreduce { ...@@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers // below are device pointers
RankSignals sg_; RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_; std::unordered_map<void*, RankData*> buffers_;
Signal *self_sg_; Signal* self_sg_;
// stores the registered device pointers from all ranks // stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_; RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_; std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers // a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
/** /**
* meta is a pointer to device metadata and temporary buffer for allreduce. * meta is a pointer to device metadata and temporary buffer for allreduce.
...@@ -279,22 +276,22 @@ class CustomAllreduce { ...@@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers * note: this class does not own any device memory. Any required buffers
* are passed in from the constructor * are passed in from the constructor
*/ */
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles, const cudaIpcMemHandle_t* handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink = true) bool full_nvlink = true)
: rank_(rank), : rank_(rank),
world_size_(offsets.size()), world_size_(offsets.size()),
full_nvlink_(full_nvlink), full_nvlink_(full_nvlink),
self_sg_(meta), self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
Signal *rank_sg; Signal* rank_sg;
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]); char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i]; handle += offsets[i];
rank_sg = (Signal *)handle; rank_sg = (Signal*)handle;
} else { } else {
rank_sg = self_sg_; rank_sg = self_sg_;
} }
...@@ -302,13 +299,13 @@ class CustomAllreduce { ...@@ -302,13 +299,13 @@ class CustomAllreduce {
} }
} }
char *open_ipc_handle(const void *ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) { if (new_handle) {
char *ipc_ptr; char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle), *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess)); cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr; it->second = ipc_ptr;
} }
...@@ -323,7 +320,7 @@ class CustomAllreduce { ...@@ -323,7 +320,7 @@ class CustomAllreduce {
std::vector<int64_t> offsets(num_buffers); std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i]; auto ptr = graph_unreg_buffers_[i];
void *base_ptr; void* base_ptr;
// note: must share the base address of each allocation, or we get wrong // note: must share the base address of each allocation, or we get wrong
// address // address
if (cuPointerGetAttribute(&base_ptr, if (cuPointerGetAttribute(&base_ptr,
...@@ -331,8 +328,8 @@ class CustomAllreduce { ...@@ -331,8 +328,8 @@ class CustomAllreduce {
(CUdeviceptr)ptr) != CUDA_SUCCESS) (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr"); throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle( CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr); offsets[i] = ((char*)ptr) - ((char*)base_ptr);
} }
return std::make_pair(handles, offsets); return std::make_pair(handles, offsets);
} }
...@@ -344,13 +341,13 @@ class CustomAllreduce { ...@@ -344,13 +341,13 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
} }
void register_buffer(const std::vector<std::string> &handles, void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, void *self) { const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity(); check_rank_data_capacity();
RankData data; RankData data;
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(handles[i].data()); char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i]; handle += offsets[i];
data.ptrs[i] = handle; data.ptrs[i] = handle;
} else { } else {
...@@ -371,17 +368,17 @@ class CustomAllreduce { ...@@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting // got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small. // mechanism so overhead should be small.
void register_graph_buffers( void register_graph_buffers(
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size(); auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers); check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers); std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i]; auto self_ptr = graph_unreg_buffers_[i];
auto &rd = rank_data[i]; auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) { for (int j = 0; j < world_size_; j++) {
if (j != rank_) { if (j != rank_) {
char *handle = char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i]; handle += offsets[j][i];
rd.ptrs[j] = handle; rd.ptrs[j] = handle;
...@@ -405,7 +402,7 @@ class CustomAllreduce { ...@@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus. * will cause contention on NVLink bus.
*/ */
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) { int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size; auto d = packed_t<T>::P::size;
if (size % d != 0) if (size % d != 0)
...@@ -418,7 +415,7 @@ class CustomAllreduce { ...@@ -418,7 +415,7 @@ class CustomAllreduce {
std::to_string(kMaxBlocks) + ". Got " + std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit)); std::to_string(block_limit));
RankData *ptrs; RankData* ptrs;
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status)); CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) { if (status == cudaStreamCaptureStatusActive) {
......
...@@ -48,7 +48,7 @@ __global__ void dummy_kernel() { ...@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
} }
template <typename T> template <typename T>
__global__ void set_data(T *data, int size, int myRank) { __global__ void set_data(T* data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f; data[idx] = myRank * 0.11f;
...@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { ...@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
} }
template <typename T> template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1, __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
double *fdata2, int size) { double* fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx]; fdata1[idx] = data1[idx];
...@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, ...@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
} }
} }
__global__ void init_rand(curandState_t *state, int size, int nRanks) { __global__ void init_rand(curandState_t* state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
...@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { ...@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
} }
template <typename T> template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
int myRank, int nRanks, int size) { int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
...@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, ...@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
} }
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test) {
T *result; T* result;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
...@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal *buffer; vllm::Signal* buffer;
T *self_data_copy; T* self_data_copy;
/** /**
* Allocate IPC buffer * Allocate IPC buffer
* *
...@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD)); MPI_BYTE, MPI_COMM_WORLD));
void *rank_data; void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024; size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0); std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank); offsets, myRank);
auto *self_data = auto* self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
std::vector<std::string> handles; std::vector<std::string> handles;
handles.reserve(nRanks); handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
char *begin = (char *)&data_handles[i]; char* begin = (char*)&data_handles[i];
char *end = (char *)&data_handles[i + 1]; char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end); handles.emplace_back(begin, end);
} }
std::vector<int64_t> offsets(nRanks, std::vector<int64_t> offsets(nRanks,
...@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa.register_buffer(handles, offsets, self_data); fa.register_buffer(handles, offsets, self_data);
} }
double *ground_truth; double* ground_truth;
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t *states; curandState_t* states;
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
...@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, ...@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK(cudaStreamDestroy(stream)); CUDACHECK(cudaStreamDestroy(stream));
} }
int main(int argc, char **argv) { int main(int argc, char** argv) {
int nRanks, myRank; int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
...@@ -296,7 +296,7 @@ int main(int argc, char **argv) { ...@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId id; ncclUniqueId id;
ncclComm_t comm; ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id); if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD)); MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
......
...@@ -6,32 +6,30 @@ ...@@ -6,32 +6,30 @@
#include <torch/extension.h> #include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
...@@ -11,26 +11,24 @@ ...@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template<typename scalar_t> template <typename scalar_t>
__global__ void rms_norm_kernel( __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx]; const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
} }
variance = blockReduceSum<float>(variance); variance = blockReduceSum<float>(variance);
...@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( ...@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
...@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( ...@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel(
Each struct should have the member static constexpr bool `exists`: Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type. If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template<typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion // CUDA < 12.0 runs into issues with packed type conversion
template<> template <>
struct _typeConvert<c10::Half> { struct _typeConvert<c10::Half> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __half; using hip_type = __half;
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support // CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely // TODO: Add in ROCm support once public headers handle bf16 maturely
template<> template <>
struct _typeConvert<c10::BFloat16> { struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float convert(hip_type x) {
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } return __bfloat162float(x);
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } __device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
}; };
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) #endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel. for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented. Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops. Alignment to 16 bytes is required to use 128-bit global memory ops.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
struct alignas(16) _f16Vec { struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should /* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */ almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0, static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!"); "Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>; using Converter = _typeConvert<scalar_t>;
...@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { ...@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i+1]}; temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] += other.data[i];
data[i] += other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i+1]}; temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] *= other.data[i];
data[i] *= other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const float scale) { __device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale; temp_f.x *= scale;
temp_f.y *= scale; temp_f.y *= scale;
T2 temp = Converter::convert(temp_f); T2 temp = Converter::convert(temp_f);
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale; float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp); data[i] = Converter::convert(temp);
...@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { ...@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__ float sum_squares() const { __device__ float sum_squares() const {
float result = 0.0f; float result = 0.0f;
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]}); float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y; result += z.x * z.x + z.y * z.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]); float x = Converter::convert(data[i]);
result += x * x; result += x * x;
...@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { ...@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
...@@ -203,9 +214,12 @@ __global__ std::enable_if_t< ...@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are /* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */ in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); auto* __restrict__ input_v =
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
...@@ -215,10 +229,11 @@ __global__ std::enable_if_t< ...@@ -215,10 +229,11 @@ __global__ std::enable_if_t<
residual_v[id] = temp; residual_v[id] = temp;
} }
/* Keep the following if-else block in sync with the /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
...@@ -233,52 +248,50 @@ __global__ std::enable_if_t< ...@@ -233,52 +248,50 @@ __global__ std::enable_if_t<
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx]; scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z; float x = (float)z;
variance += x * x; variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z; residual[blockIdx.x * hidden_size + idx] = z;
} }
/* Keep the following if-else block in sync with the /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& weight, // [hidden_size] float epsilon) {
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -286,40 +299,27 @@ void rms_norm( ...@@ -286,40 +299,27 @@ void rms_norm(
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input.scalar_type(), vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
"rms_norm_kernel", out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
[&] { weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( });
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \ vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \ <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \ residual.data_ptr<scalar_t>(), \
<scalar_t, width><<<grid, block, 0, stream>>>( \ weight.data_ptr<scalar_t>(), epsilon, \
input.data_ptr<scalar_t>(), \ num_tokens, hidden_size); \
residual.data_ptr<scalar_t>(), \ });
weight.data_ptr<scalar_t>(), \
epsilon, \ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
num_tokens, \ torch::Tensor& residual, // [..., hidden_size]
hidden_size); \ torch::Tensor& weight, // [hidden_size]
}); float epsilon) {
void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -342,8 +342,8 @@ void fused_add_rms_norm( ...@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ bool ptrs_are_aligned =
&& wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
} else { } else {
......
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
} }
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights, torch::Tensor& token_expert_indices,
torch::Tensor& topk_indices, torch::Tensor& gating_output);
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
...@@ -7,119 +7,128 @@ ...@@ -7,119 +7,128 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm { namespace vllm {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
// don't worry about overflow because num_experts is relatively small int32_t col) {
return row * total_col + col; // don't worry about overflow because num_experts is relatively small
} return row * total_col + col;
} }
} // namespace
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t *sorted_token_ids, int32_t* sorted_token_ids,
int32_t *expert_ids, int32_t* expert_ids,
int32_t *total_tokens_post_pad, int32_t* total_tokens_post_pad,
int32_t num_experts, int32_t num_experts,
int32_t block_size, int32_t block_size, size_t numel) {
size_t numel) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread;
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[];
extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts =
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) int32_t* cumsum =
shared_mem + (num_experts + 1) *
for (int i = 0; i < num_experts; ++i) { num_experts; // 1d tensor with shape (num_experts + 1)
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
} for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
/** }
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned /**
* to expert expert_index. * In the first step we compute token_cnts[thread_index + 1][expert_index],
*/ * which counts how many tokens in the token shard of thread_index are
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { * assigned to expert expert_index.
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; */
} for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
__syncthreads(); }
// For each expert we accumulate the token counts from the different threads. __syncthreads();
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) { // For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
} for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
__syncthreads(); tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) { __syncthreads();
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { // We accumulate the token counts of all experts in thread 0.
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; if (threadIdx.x == 0) {
} cumsum[0] = 0;
*total_tokens_post_pad = cumsum[num_experts]; for (int i = 1; i <= num_experts; ++i) {
} cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
__syncthreads(); block_size) *
block_size;
/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
} }
*total_tokens_post_pad = cumsum[num_experts];
/** }
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and __syncthreads();
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python). /**
*/ * For each expert, each thread processes the tokens of the corresponding
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { * blocks and stores the corresponding expert_id for each block.
int32_t expert_id = topk_ids[i]; */
/** The cumsum[expert_id] stores the starting index of the tokens that the for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] i += block_size) {
* stores the indices of the tokens processed by the expert with expert_id within expert_ids[i / block_size] = threadIdx.x;
* the current thread's token shard. }
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; /**
sorted_token_ids[rank_post_pad] = i; * Each thread processes a token shard, calculating the index of each token
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; * after sorting by expert number. Given the example topk_ids =
} * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
} * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
} }
} // namespace vllm
void moe_align_block_size(
torch::Tensor topk_ids, void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int num_experts, int block_size, torch::Tensor sorted_token_ids,
int block_size, torch::Tensor experts_ids,
torch::Tensor sorted_token_ids, torch::Tensor num_tokens_post_pad) {
torch::Tensor experts_ids, const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor num_tokens_post_pad) { VLLM_DISPATCH_INTEGRAL_TYPES(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
VLLM_DISPATCH_INTEGRAL_TYPES( // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // tensors
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors const int32_t shared_mem =
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); ((num_experts + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); (void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel()); topk_ids.numel());
}); });
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -117,10 +117,10 @@ struct cutlass_2x_gemm { ...@@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
...@@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, ...@@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using StrideC = Stride<int64_t, Int<1>, Int<0>>; using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const *>(a.data_ptr()); auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const *>(b.data_ptr()); auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD *>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>(); auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>(); auto b_scales_ptr = b_scales.data_ptr<float>();
...@@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, ...@@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
} // namespace } // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
...@@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, ...@@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
} }
} }
void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
...@@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, ...@@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
} }
} }
void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const &b, torch::Tensor const& b,
torch::Tensor const &a_scales, torch::Tensor const& a_scales,
torch::Tensor const &b_scales) { torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment