Unverified Commit d052f4c8 authored by Lianmin Zheng, committed by GitHub

New clang format for sgl kernel (#4194)

parent e1aaa79a
cp ../README.md ../LICENSE .
rm -rf dist
python3 -m build
python3 -m twine upload dist/*
rm -rf README.md LICENSE
@@ -6,3 +6,10 @@ DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
AllowShortLoopsOnASingleLine: false
BinPackParameters: false # Prevents packing parameters in declarations
BinPackArguments: false # Prevents packing arguments in function calls
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
AlignOperands: Align # Aligns arguments vertically
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
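
A minimal sketch (hypothetical function and argument names, not from the repository) of the wrapping these options produce: with BinPackParameters/BinPackArguments disabled and AlignAfterOpenBracket set to AlwaysBreak, any declaration or call that exceeds the column limit breaks right after the opening parenthesis and places one parameter per line.

// Hypothetical example of the new style; names are illustrative only.
void fused_update(
    float* input,
    float* residual,
    const float* weight,
    int batch_size,
    int hidden_size,
    float eps);

void caller(float* in, float* res, const float* w) {
  // Short calls stay on one line; long ones wrap one argument per line.
  fused_update(
      in,
      res,
      w,
      /*batch_size=*/32,
      /*hidden_size=*/4096,
      /*eps=*/1e-6f);
}
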
@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status = norm::FusedAddRMSNorm(
        static_cast<c_type*>(input.data_ptr()),
        static_cast<c_type*>(residual.data_ptr()),
        static_cast<c_type*>(weight.data_ptr()),
        batch_size,
        hidden_size,
        eps,
        torch_current_stream);
    TORCH_CHECK(
        status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(
    const RankSignals& sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    int rank) {
#ifdef USE_ROCM
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(
        &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
           flag)
@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(
    const RankSignals& sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    int rank) {
#ifdef USE_ROCM
  __syncthreads();
  // eliminate the case that prior writes are not visible after signals become
@@ -209,11 +211,16 @@ DINLINE void end_sync(
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(
        &sg.signals[threadIdx.x]->end[blockIdx.x][rank],
        flag,
        final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
        __MEMORY_SCOPE_SYSTEM);
    // wait until we got true from all ranks
    while (__scoped_atomic_load_n(
               &self_sg->end[blockIdx.x][threadIdx.x],
               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
               __MEMORY_SCOPE_DEVICE) < flag)
      ;
  }
  __syncthreads();
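
The ordering choice above follows the usual flag-handshake pattern: unless this is the final barrier (where no prior writes need to be published), the writer uses a release store and the reader spins on an acquire load. A host-side C++ sketch of the same pairing with std::atomic, purely for illustration (the kernel itself uses the device-scoped __scoped_atomic builtins):

#include <atomic>
#include <cstdint>

std::atomic<uint32_t> end_flag{0};
int payload = 0;

// Writer: the release store makes the earlier write to `payload` visible
// to any thread whose acquire load observes `flag`.
void signal_done(uint32_t flag) {
  payload = 42;
  end_flag.store(flag, std::memory_order_release);
}

// Reader: spin until the flag arrives; afterwards `payload` can be read safely.
void wait_done(uint32_t flag) {
  while (end_flag.load(std::memory_order_acquire) < flag) {
    // spin, analogous to the __scoped_atomic_load_n loop in end_sync
  }
}
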
@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
    RankData* _dp,
    RankSignals sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    T* __restrict__ result,
    int rank,
    int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
    RankData* _dp,
    RankSignals sg,
#ifndef USE_ROCM
    volatile
#endif
    Signal* self_sg,
    T* __restrict__ result,
    int rank,
    int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
@@ -357,8 +372,14 @@ class CustomAllreduce {
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(
      Signal* meta,
      void* rank_data,
      size_t rank_data_sz,
      const hipIpcMemHandle_t* handles,
      const std::vector<int64_t>& offsets,
      int rank,
      bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
@@ -382,8 +403,8 @@ class CustomAllreduce {
    auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(hipIpcOpenMemHandle(
          (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
@@ -399,13 +420,14 @@ class CustomAllreduce {
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (hipPointerGetAttribute(
              &base_ptr,
#ifdef USE_ROCM
              HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
              CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
              (hipDeviceptr_t)ptr) != hipSuccess)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
@@ -415,8 +437,8 @@ class CustomAllreduce {
  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }
  void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
@@ -443,8 +465,8 @@ class CustomAllreduce {
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void
  register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
@@ -474,11 +496,17 @@ class CustomAllreduce {
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(
      hipStream_t stream,
      T* input,
      T* output,
      int size,
#ifndef USE_ROCM
      int threads = 512,
      int block_limit = 36){
#else
      int threads = 512,
      int block_limit = 16) {
#endif
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
@@ -487,8 +515,8 @@ class CustomAllreduce {
              "of " +
              std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error(
          "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
    RankData* ptrs;
    hipStreamCaptureStatus status;
@@ -499,17 +527,17 @@ class CustomAllreduce {
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
      ptrs = it->second;
    }
    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  hipLaunchKernelGGL( \
      (name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
  case ngpus: { \
    if (world_size_ == 2) { \
......
@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
  return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx) {
  // After this function, at least one block in each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, world_size]
@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
    uint32_t** signals,
    uint32_t const flag,
    size_t const local_rank,
    size_t const world_size,
    int const tidx,
    int const bidx,
    int const grid_size) {
  if constexpr (!start) {
    __syncthreads();
  }
@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
    }
  }
  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
      }
    }
  }
  block_barrier<true>(
      params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
    }
  }
  block_barrier<false, true>(
      params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
  // Gather all needed elts from other intra-node ranks
  for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
    AllReduceStrategyType algo,
    AllReduceParams& param,
    int blocks_per_grid,
    int threads_per_block,
    cudaStream_t stream) {
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
  CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }
......
@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
 public:
  AllReduceMeta(
      int64_t rank_id,
      int64_t world_size,
      torch::Tensor& rank_data,
      const std::vector<fptr_t>& buffers,
      const std::vector<fptr_t>& tmp_result_buffers,
      const std::vector<fptr_t>& barrier_in,
      const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->barrier_in = barrier_in;
@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
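
For reference, the check in CanApplyCustomAllReduce requires the element count to split into whole 16-byte packets; a worked sketch of the arithmetic (assuming get_bits returns the dtype's bit width, e.g. 16 for half and 32 for float):

#include <cstdint>

// half:  bytes = (16 + 7) / 8 = 2, so 16 / 2 = 8 elements per packet -> num_elements % 8 == 0
// float: bytes = (32 + 7) / 8 = 4, so 16 / 4 = 4 elements per packet -> num_elements % 4 == 0
bool fits_16_byte_packets(int64_t num_elements, int bit_width) {
  int bytes_per_element = (bit_width + 7) / 8;
  int elements_per_packet = 16 / bytes_per_element;
  return num_elements % elements_per_packet == 0;
}
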
fptr_t init_custom_ar(
    int64_t rank_id,
    int64_t world_size,
    torch::Tensor& rank_data,
    const std::vector<fptr_t>& buffers,
    const std::vector<fptr_t>& tmp_result_buffers,
    const std::vector<fptr_t>& barrier_in,
    const std::vector<fptr_t>& barrier_out) {
  auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
  return (fptr_t)m;
}
@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
  auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
  if (new_handle) {
    char* ipc_ptr;
    CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
        (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
    it->second = ipc_ptr;
  }
  return it->second;
@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
    fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  std::vector<std::string> handle_bytes;
  handle_bytes.reserve(handles.size());
......
@@ -23,15 +23,18 @@ limitations under the License.
#define THREADS_PER_BLOCK 128
template <typename T>
__global__ void lightning_attention_decode_kernel(
    const T* __restrict__ q,            // [b, h, 1, d]
    const T* __restrict__ k,            // [b, h, 1, d]
    const T* __restrict__ v,            // [b, h, 1, e]
    const float* __restrict__ past_kv,  // [b, h, d, e]
    const float* __restrict__ slope,    // [h, 1, 1]
    T* __restrict__ output,             // [b, h, 1, e]
    float* __restrict__ new_kv,         // [b, h, d, e]
    const int batch_size,
    const int num_heads,
    const int qk_dim,
    const int v_dim) {
  extern __shared__ char smem[];
  T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
  T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
  }
}
void lightning_attention_decode(
    const torch::Tensor& q,
    const torch::Tensor& k,
    const torch::Tensor& v,
    const torch::Tensor& past_kv,
    const torch::Tensor& slope,
    torch::Tensor output,
    torch::Tensor new_kv) {
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(),
            v.data_ptr<scalar_t>(),
            past_kv.data_ptr<float>(),
            slope.data_ptr<float>(),
            output.data_ptr<scalar_t>(),
            new_kv.data_ptr<float>(),
            batch_size,
            num_heads,
            qk_dim,
            v_dim);
      }));
}
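
To give the shared-memory formula above a concrete number (hypothetical sizes, not taken from this diff): with qk_dim = 128, v_dim = 64 and half-precision inputs, a block allocates (2*128 + 2*64) two-byte elements plus 128*(64 + 1) floats, i.e. 768 + 33280 = 34048 bytes. A small helper restating the same computation:

#include <cstddef>

// scalar staging (2*qk_dim + 2*v_dim elements of elem_size bytes)
// plus a qk_dim x (v_dim + 1) float workspace, per block.
size_t lightning_decode_smem_bytes(int qk_dim, int v_dim, size_t elem_size) {
  return (2 * qk_dim + 2 * v_dim) * elem_size + qk_dim * (v_dim + 1) * sizeof(float);
}
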
@@ -25,9 +25,15 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
template <
    typename ThreadblockShape_,
    int ThreadCount,
    typename ScaleTileIterator_,
    typename OutputTileIterator_,
    typename ElementAccumulator_,
    typename ElementCompute_,
    typename ElementwiseFunctor_,
    bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
 public:
  using ThreadblockShape = ThreadblockShape_;
@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
    Arguments(typename ElementwiseFunctor::Params elementwise_)
        : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
    Arguments(
        typename ElementwiseFunctor::Params elementwise_,
        int64_t batch_stride_alpha_,
        int64_t batch_stride_C_,
        int64_t batch_stride_D_)
        : elementwise(elementwise_),
          batch_stride_alpha(batch_stride_alpha_),
          batch_stride_C(batch_stride_C_),
@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
 public:
  CUTLASS_DEVICE
  EpilogueVisitorPerRowPerCol(
      Params const& params,
      SharedStorage& shared_storage,
      cutlass::MatrixCoord const& problem_size,
      int thread_idx,
      int warp_idx,
      int lane_idx,
      typename ScaleTileIterator::Params params_alpha_col,
      typename OutputTileIterator::Params params_C,
      typename OutputTileIterator::Params params_D,
      bool with_bias,
      bool per_token_quant,
      bool per_channel_quant,
      AlphaScaleElementType* ptr_alpha_row,
      AlphaScaleElementType* ptr_alpha_col,
      typename OutputTileIterator::Element* ptr_C,
      typename OutputTileIterator::Element* ptr_D,
      cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
      int column_offset = 0,
      cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
      : params_(params),
        shared_storage_(shared_storage),
        extent_(problem_size),
@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
  /// Helper to indicate split-K behavior
  CUTLASS_DEVICE
  void set_k_partition(
      int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
      int split_k_slices) {  ///< Total number of split-K slices
  }
  /// Called to set the batch index
@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
 private:
  CUTLASS_DEVICE
  ComputeFragment per_token_channel_scale_accumulator_(
      ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {
@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
  }
  CUTLASS_DEVICE
  ComputeFragment per_token_scale_accumulator_(
      ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
    ComputeFragment result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < ComputeFragment::kElements; ++i) {
......
@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
    int Stages_,
    class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
                               // while zero-value `ScaleGranularityM` indicates that scaling
                               // granularity is `size<0>(TileShape_MNK{})` along M.
    >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
  static_assert(
      cute::
          is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
      "KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
......
@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
    CUTLASS_TRACE_HOST(
        " grid_tiled_shape: " << grid_tiled_shape << "\n"
                              << " result = {" << result << "}");
    return result;
  }
@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
    CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
    if (smem_size <= (48 << 10)) {
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
      if (result == cudaSuccess) {
        CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
      }
    } else {
      // Query assuming zero shared memory then compute occupancy limit based on SMEM
      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
      if (result != cudaSuccess) {
        CUTLASS_TRACE_HOST(
            " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
        return -1;
      }
@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST(
        "GemmUniversalBaseCompat::initialize() - workspace " << workspace
                                                             << ", stream: " << (stream ? "non-null" : "null"));
    size_t workspace_bytes = get_workspace_size(args);
......
@@ -32,10 +32,11 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
    typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,           ///! Epilogue
    typename ThreadblockSwizzle_  ///! Threadblock swizzling function
    >
struct GemmWithEpilogueVisitor {
 public:
  using Mma = Mma_;
@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
    Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
    /// constructs an arguments structure
    Arguments(
        GemmCoord problem_size_,
        TensorRefA ref_A_,
        TensorRefB ref_B_,
        TensorRefAlphaCol ref_alpha_col_,
        TensorRefAlphaRow ref_alpha_row_,
        TensorRefC ref_C_,
        TensorRefC ref_D_,
        typename EpilogueVisitor::Arguments epilogue_visitor_)
        : mode(GemmUniversalMode::kGemm),
          problem_size(problem_size_),
          batch_count(1),
@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
      isAMisaligned = problem_size.k() % kAlignmentA;
    } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
      isAMisaligned = problem_size.m() % kAlignmentA;
    } else if (
        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
        platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
      isAMisaligned = problem_size.k() % kAlignmentA;
    }
@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
      isBMisaligned = problem_size.n() % kAlignmentB;
    } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    } else if (
        platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
        platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
      isBMisaligned = problem_size.k() % kAlignmentB;
    }
@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
      isCMisaligned = problem_size.n() % kAlignmentC;
    } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
      isCMisaligned = problem_size.m() % kAlignmentC;
    } else if (
        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
        platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
      isCMisaligned = problem_size.n() % kAlignmentC;
    }
@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
    int thread_idx = threadIdx.x;
    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
        params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
    typename Mma::IteratorB iterator_B(
        params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
    threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
    // assume identity swizzle
    MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
      with_bias = false;
    }
    EpilogueVisitor epilogue_visitor(
        params.epilogue_visitor,
        shared_storage.epilogue.visitor,
        params.problem_size.mn(),
        thread_idx,
        warp_idx,
        lane_idx,
        params.params_alpha_col,
        params.params_C,
        params.params_D,
        with_bias,
        true,
        true,
        params.ptr_alpha_row,
        params.ptr_alpha_col,
        params.ptr_C,
        params.ptr_D,
        threadblock_offset,
        blockIdx.y * params.problem_size.m());
    if (params.mode == GemmUniversalMode::kGemm) {
      // Indicate which position in a serial reduction the output operator is currently updating
......
@@ -21,10 +21,13 @@
#include "utils.h"
static void check_group_count(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs) {
  TORCH_CHECK(
      ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
      "The group count of inputs, weights and outputs should be the same.");
}
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
  torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
  TORCH_CHECK(
      cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
      CUBLAS_STATUS_SUCCESS);
  return gpu_ptrs;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can"
      "only be applied to float16 and bfloat16 dtype");
  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);
@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs,  // b: (m, k
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
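
A plain C++ reference of the layout argument in the comments at the top of cublas_grouped_gemm, purely illustrative: a row-major (m, k) buffer reinterpreted as column-major has shape (k, m), so computing weight @ input^T in column major writes exactly the row-major bytes of output = input @ weight^T.

#include <vector>

// Host-side reference loop (illustrative names) for output = input @ weight^T,
// indexed the way the column-major cuBLAS call ultimately stores the result.
void grouped_gemm_reference(
    const std::vector<float>& input,   // (m, k) row major
    const std::vector<float>& weight,  // (n, k) row major
    std::vector<float>& output,        // (m, n) row major
    int m,
    int n,
    int k) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        acc += input[i * k + p] * weight[j * k + p];  // dot(input row i, weight row j)
      }
      // Element (i, j) of the row-major output is element (j, i) of the
      // (n, m) column-major matrix the cuBLAS call produces; the bytes coincide.
      output[i * n + j] = acc;
    }
  }
}
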
@@ -35,8 +35,12 @@
using namespace cute;
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;
@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      EpilogueTileType,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueSchedule,
      StoreEpilogueCompute>::CollectiveOp;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      LayoutA,
      AlignmentA,
      ElementB,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;
@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
}
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
...@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T ...@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif #endif
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(
"No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
} }
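For orientation, here is a minimal host-side sketch of how fp8_blockwise_scaled_mm might be called once built into the sgl-kernel extension. The 128-element scale granularity and the scale tensor shapes are assumptions for illustration only; the column-major construction of mat_b is what the stride(0) == 1 check above requires.

#include <torch/torch.h>

// Hypothetical caller (not part of this commit): A is [M, K] row-major,
// B is [K, N] column-major, both Float8_e4m3fn; the scale shapes are assumed.
torch::Tensor call_fp8_blockwise_scaled_mm_sketch(int64_t M, int64_t N, int64_t K) {
  auto fp8 = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat8_e4m3fn);
  auto f32 = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32);
  torch::Tensor mat_a = torch::empty({M, K}, fp8);               // row-major
  torch::Tensor mat_b = torch::empty({N, K}, fp8).t();           // [K, N] with stride(0) == 1
  torch::Tensor scales_a = torch::ones({M, K / 128}, f32);       // assumed blockwise layout
  torch::Tensor scales_b = torch::ones({K / 128, N / 128}, f32); // assumed blockwise layout
  return fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, torch::kBFloat16);
}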
@@ -8,8 +8,8 @@

#include "utils.h"

template <typename T>
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}

template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  const float scale_val = 1.0f / (*scale);
@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
    }

    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
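The two kernels above split per-tensor quantization into an absmax reduction followed by a scale-and-cast pass. A rough serial reference of that math, assuming the usual convention scale = absmax / FP8_E4M3_MAX with FP8_E4M3_MAX = 448 and standing in plain float for the FP8 storage:

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative serial reference only; the real kernels do this in parallel on the GPU.
void per_tensor_quant_fp8_reference(const std::vector<float>& input,
                                    std::vector<float>& output_q, float& output_s) {
  constexpr float kFp8Max = 448.0f;  // assumed e4m3 max
  float max_value = 0.0f;
  for (float v : input) max_value = std::max(max_value, std::fabs(v));  // absmax pass
  output_s = max_value / kFp8Max;                                       // per-tensor scale
  const float scale_val = output_s > 0.0f ? 1.0f / output_s : 0.0f;     // mirrors 1.0f / (*scale)
  output_q.resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Clamp to the representable range; the kernel then casts to FP8_TYPE.
    output_q[i] = std::min(std::max(input[i] * scale_val, -kFp8Max), kFp8Max);
  }
}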
@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}

template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int local_group_id = threadIdx.x / 16;
  const int lane_id = threadIdx.x % 16;
@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
  }
}

void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);
@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        (float)eps,
        (float)fp8_min,
        (float)fp8_max);
    return true;
  });
}
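For the group-wise variant, each contiguous run of group_size elements gets its own scale in output_s; eps keeps the scale away from zero, and fp8_min/fp8_max bound the quantized values. A serial sketch of that mapping (FP8 storage again stood in by float, and the exact scale formula is an assumption):

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative serial reference for per-token-group quantization.
void per_token_group_quant_fp8_reference(const std::vector<float>& input,
                                         std::vector<float>& output_q, std::vector<float>& output_s,
                                         int group_size, float eps, float fp8_min, float fp8_max) {
  const int num_groups = static_cast<int>(input.size()) / group_size;
  output_q.assign(input.size(), 0.0f);
  output_s.assign(num_groups, 0.0f);
  for (int g = 0; g < num_groups; ++g) {
    float absmax = 0.0f;
    for (int i = 0; i < group_size; ++i)
      absmax = std::max(absmax, std::fabs(input[g * group_size + i]));
    const float scale = std::max(absmax, eps) / fp8_max;  // one scale per group
    output_s[g] = scale;
    for (int i = 0; i < group_size; ++i) {
      const float q = input[g * group_size + i] / scale;
      output_q[g * group_size + i] = std::min(std::max(q, fp8_min), fp8_max);
    }
  }
}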
@@ -7,9 +7,12 @@

#include "utils.h"

template <typename T>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;
@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
    return true;
  });
}
@@ -25,9 +25,11 @@ limitations under the License.

#define WARP_SIZE 32

template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum) {
  __shared__ int32_t shared_counts[WARP_SIZE][8];

  const int warp_id = threadIdx.x / WARP_SIZE;
@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
  }
}

void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only supports DeepSeek V3 for now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    align_kernel<<<1, 1024, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>());

    const int block_threads = 256;
    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
    const int actual_blocks = std::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        cumsum_buffer.data_ptr<int32_t>(),
        topk_ids.numel());
  });
}
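Conceptually, moe_align_block_size buckets the flattened topk_ids by expert, rounds each expert's token count up to a multiple of block_size, and uses the padded cumulative sum to place the sorted token ids. A serial sketch of the padding/cumsum step (function and variable names are hypothetical):

#include <cstdint>
#include <vector>

// counts[e] = number of tokens routed to expert e.
// Fills the padded exclusive cumsum and returns the total padded token count.
int32_t padded_expert_cumsum_sketch(const std::vector<int32_t>& counts, int32_t block_size,
                                    std::vector<int32_t>& cumsum) {
  cumsum.assign(counts.size() + 1, 0);
  for (size_t e = 0; e < counts.size(); ++e) {
    const int32_t padded = ((counts[e] + block_size - 1) / block_size) * block_size;  // round up
    cumsum[e + 1] = cumsum[e] + padded;
  }
  return cumsum.back();  // total number of (possibly padded) token slots
}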
@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..]
//   = [sum(verified_seq_len)*draft_token + bs*draft_token*draft_token]
// positions [bs * draft_token]
// retrive_index [b, draft_token]
// retrive_next_token [b, draft_token]
// retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
  }
}

void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree_efficient<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}

// parent_list [bs, topk * (depth - 1) + 1]
@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..]
//   = [sum(verified_seq_len)*draft_token + bs*draft_token*draft_token]
// positions [bs * draft_token]
// retrive_index [b, draft_token, depth + 2]
__global__ void build_tree(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
  }
}

void build_tree_kernel(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}
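The layout comments above fix the size of the flattened tree_mask buffer: each batch element contributes draft_token * (seq_len[i] + draft_token) entries. A small helper that computes this total (purely illustrative, not part of the commit):

#include <cstdint>
#include <vector>

// Total tree_mask length per the layout comment:
// sum(verified_seq_len) * draft_token + bs * draft_token * draft_token.
int64_t tree_mask_numel_sketch(const std::vector<int32_t>& verified_seq_len, int64_t draft_token) {
  int64_t total = 0;
  for (int32_t len : verified_seq_len) {
    total += (static_cast<int64_t>(len) + draft_token) * draft_token;  // per-batch slab
  }
  return total;
}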