Unverified commit d052f4c8, authored by Lianmin Zheng, committed by GitHub

New clang format for sgl kernel (#4194)

parent e1aaa79a
cp ../README.md ../LICENSE .
rm -rf dist
python3 -m build
python3 -m twine upload dist/*
rm -rf README.md LICENSE
@@ -6,3 +6,10 @@ DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
AllowShortLoopsOnASingleLine: false
BinPackParameters: false # Prevents packing parameters in declarations
BinPackArguments: false # Prevents packing arguments in function calls
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
AlignOperands: Align # Aligns arguments vertically
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
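Taken together, these options make clang-format break after the opening parenthesis and place each parameter or argument on its own line instead of packing them. A hypothetical before/after sketch (the function name and parameters are illustrative, not part of this commit):

// Before: parameters packed and aligned to the open parenthesis
void fused_op(torch::Tensor input, torch::Tensor weight, int64_t batch_size,
              int64_t hidden_size, double eps);

// After: break after '(' (AlignAfterOpenBracket: AlwaysBreak),
// one parameter per line (BinPackParameters: false)
void fused_op(
    torch::Tensor input,
    torch::Tensor weight,
    int64_t batch_size,
    int64_t hidden_size,
    double eps);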
@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
// support float16, bfloat16 and float32
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(
static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()),
batch_size,
hidden_size,
eps,
torch_current_stream);
TORCH_CHECK(
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
}
@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
flag)
@@ -193,12 +194,13 @@ DINLINE void start_sync(
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
@@ -209,11 +211,16 @@ DINLINE void end_sync(
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(
&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__MEMORY_SCOPE_DEVICE) < flag)
;
}
__syncthreads();
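For readers unfamiliar with the flag-exchange pattern in start_sync/end_sync above, the following is a small host-side analogy using std::atomic instead of the kernel's scoped GPU atomics. The two-rank setup and names are illustrative only and are not part of the commit:

#include <atomic>
#include <cstdint>
#include <thread>

// Each "rank" owns one slot per peer; a barrier round consists of every rank
// writing the round's flag value into its slot on every peer, then spinning
// until all of its own slots have reached that flag.
constexpr int kRanks = 2;
std::atomic<uint32_t> slots[kRanks][kRanks];  // slots[owner][writer], zero-initialized

void barrier(int rank, uint32_t flag) {
  for (int peer = 0; peer < kRanks; ++peer)
    slots[peer][rank].store(flag, std::memory_order_release);  // announce arrival
  for (int peer = 0; peer < kRanks; ++peer)
    while (slots[rank][peer].load(std::memory_order_acquire) < flag)
      ;  // spin until every peer has announced the same round
}

int main() {
  std::thread t0([] { barrier(0, 1); });
  std::thread t1([] { barrier(1, 1); });
  t0.join();
  t1.join();
}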
@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
@@ -357,8 +372,14 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(
Signal* meta,
void* rank_data,
size_t rank_data_sz,
const hipIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets,
int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
@@ -382,8 +403,8 @@ class CustomAllreduce {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(hipIpcOpenMemHandle(
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
@@ -399,13 +420,14 @@ class CustomAllreduce {
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (hipPointerGetAttribute(
&base_ptr,
#ifdef USE_ROCM
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(hipDeviceptr_t)ptr) != hipSuccess)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
@@ -415,8 +437,8 @@ class CustomAllreduce {
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
@@ -443,8 +465,8 @@ class CustomAllreduce {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
@@ -474,11 +496,17 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(
hipStream_t stream,
T* input,
T* output,
int size,
#ifndef USE_ROCM
int threads = 512,
int block_limit = 36){
#else
int threads = 512,
int block_limit = 16) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
@@ -487,8 +515,8 @@ class CustomAllreduce {
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
hipStreamCaptureStatus status;
@@ -499,17 +527,17 @@ class CustomAllreduce {
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
hipLaunchKernelGGL( \
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
...
@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx,
int const grid_size) {
if constexpr (!start) {
__syncthreads();
}
@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
}
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
block_barrier<false, true>(
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(
AllReduceStrategyType algo,
AllReduceParams& param,
int blocks_per_grid,
int threads_per_block,
cudaStream_t stream) {
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}
...
@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
public:
AllReduceMeta(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->barrier_in = barrier_in;
@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
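As a worked example of the divisibility rule in CanApplyCustomAllReduce a few lines above: the element count must fill whole 16-byte packets, so for a 2-byte dtype (fp16/bf16) it must be a multiple of 16 / 2 = 8, and for fp32 a multiple of 4. The small sketch below just reproduces that arithmetic; the dtype names are for illustration only.

#include <cstdio>

// bits -> bytes -> elements per 16-byte packet, mirroring
// 16 / ((get_bits(dtype) + 7) / 8) in the check above.
int main() {
  const int bits[] = {16, 16, 32};  // fp16, bf16, fp32
  const char* names[] = {"fp16", "bf16", "fp32"};
  for (int i = 0; i < 3; ++i) {
    int bytes = (bits[i] + 7) / 8;
    printf("%s: num_elements must be a multiple of %d\n", names[i], 16 / bytes);
  }
}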
@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
std::vector<std::string> handle_bytes;
handle_bytes.reserve(handles.size());
...
@@ -23,15 +23,18 @@ limitations under the License.
#define THREADS_PER_BLOCK 128
template <typename T>
__global__ void lightning_attention_decode_kernel(
const T* __restrict__ q,  // [b, h, 1, d]
const T* __restrict__ k,  // [b, h, 1, d]
const T* __restrict__ v,  // [b, h, 1, e]
const float* __restrict__ past_kv,  // [b, h, d, e]
const float* __restrict__ slope,  // [h, 1, 1]
T* __restrict__ output,  // [b, h, 1, e]
float* __restrict__ new_kv,  // [b, h, d, e]
const int batch_size,
const int num_heads,
const int qk_dim,
const int v_dim) {
extern __shared__ char smem[];
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(
}
}
void lightning_attention_decode(
const torch::Tensor& q,
const torch::Tensor& k,
const torch::Tensor& v,
const torch::Tensor& past_kv,
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv) {
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
@@ -131,8 +139,16 @@ void lightning_attention_decode(
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
q.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(),
v.data_ptr<scalar_t>(),
past_kv.data_ptr<float>(),
slope.data_ptr<float>(),
output.data_ptr<scalar_t>(),
new_kv.data_ptr<float>(),
batch_size,
num_heads,
qk_dim,
v_dim);
}));
}
@@ -25,9 +25,15 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
template <
typename ThreadblockShape_,
int ThreadCount,
typename ScaleTileIterator_,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementwiseFunctor_,
bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
int64_t batch_stride_alpha_,
int64_t batch_stride_C_,
int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(
Params const& params,
SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
bool with_bias,
bool per_token_quant,
bool per_channel_quant,
AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col,
typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index,  ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) {  ///< Total number of split-K slices
}
/// Called to set the batch index
@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
...
@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <
int Stages_,
class ClusterShape_ = Shape<_1, _1, _1>,
class KernelSchedule = KernelTmaWarpSpecialized,
int ScaleGranularityM = 0  // `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
"KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
...
@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(
" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
}
} else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST(
"GemmUniversalBaseCompat::initialize() - workspace " << workspace
<< ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
...
@@ -32,10 +32,11 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_,  ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_,  ///! Epilogue
typename ThreadblockSwizzle_  ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
/// constructs an arguments structure
Arguments(
GemmCoord problem_size_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_,
TensorRefC ref_C_,
TensorRefC ref_D_,
typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(GemmUniversalMode::kGemm),
problem_size(problem_size_),
batch_count(1),
@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (
platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
with_bias = false;
}
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params.params_alpha_col,
params.params_C,
params.params_D,
with_bias,
true,
true,
params.ptr_alpha_row,
params.ptr_alpha_col,
params.ptr_C,
params.ptr_D,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
...
@@ -21,10 +21,13 @@
#include "utils.h"
static void check_group_count(
const std::vector<torch::Tensor>& inputs,
const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs) {
TORCH_CHECK(
((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
"The group count of inputs, weights and outputs should be the same.");
}
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
TORCH_CHECK(
cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
CUBLAS_STATUS_SUCCESS);
return gpu_ptrs;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void cublas_grouped_gemm(
const std::vector<torch::Tensor>& inputs,  // b: (m, k) row major = (k, m) col major
const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (n, k)^T col major
const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
const torch::Dtype& out_dtype,
int64_t cublas_handle,
int64_t cuda_stream) {
TORCH_CHECK(
out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
"cublas grouped_gemm can"
"only be applied to float16 and bfloat16 dtype");
int group_count = inputs.size();
check_group_count(inputs, weights, outputs);
@@ -133,16 +141,32 @@ void cublas_grouped_gemm(
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
auto status = cublasGemmGroupedBatchedEx(
handle,
transa_array.data(),
transb_array.data(),
m_array.data(),
n_array.data(),
k_array.data(),
alpha_array.data(),
(void**)d_a.data_ptr(),
cuda_data_type,
lda_array.data(),
(void**)d_b.data_ptr(),
cuda_data_type,
ldb_array.data(),
beta_array.data(),
(void**)d_c.data_ptr(),
cuda_data_type,
ldc_array.data(),
group_count,
group_size.data(),
CUBLAS_COMPUTE_32F);
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
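The layout trick described in the comments above cublas_grouped_gemm rests on one identity: a row-major (m, k) buffer, read as a column-major matrix with leading dimension k, is exactly the (k, m) transpose. The small CPU-only sketch below demonstrates that identity; the matrix sizes are illustrative and no cuBLAS call is made.

#include <cstdio>

// Row-major A is 2 x 3; reading the same buffer column-major with ld = 3
// yields A^T, which is why row-major PyTorch tensors can be fed straight
// into the column-major cuBLAS grouped GEMM above.
int main() {
  const int m = 2, k = 3;
  const float a[m * k] = {1, 2, 3, 4, 5, 6};  // row-major 2x3: [[1,2,3],[4,5,6]]
  for (int i = 0; i < k; ++i) {      // rows of the column-major k x m view
    for (int j = 0; j < m; ++j)
      printf("%g ", a[j * k + i]);   // element (i, j) of the view == A(j, i)
    printf("\n");                    // prints A^T: (1,4), (2,5), (3,6)
  }
}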
@@ -35,8 +35,12 @@
using namespace cute;
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueSchedule,
StoreEpilogueCompute>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
LayoutA,
AlignmentA,
ElementB,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,  // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm_op;
@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(
}
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}
torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
TORCH_CHECK(
(mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(
(mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
...@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T ...@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif #endif
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(
"No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
} }
...@@ -53,10 +53,17 @@ limitations under the License. ...@@ -53,10 +53,17 @@ limitations under the License.
using namespace cute; using namespace cute;
#if defined CUDA_VERSION && CUDA_VERSION >= 12040 #if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape, template <
typename WarpShape, int Stages, bool WithBias, typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, typename ElementType,
template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, typename OutElementType,
typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> typename AccumElementType,
typename CtaShape,
typename WarpShape,
int Stages,
bool WithBias,
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
template <typename...> typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT,
typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>>
struct DeviceGemmFp8RowwiseSm89 { struct DeviceGemmFp8RowwiseSm89 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)"); static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
...@@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 { ...@@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 {
// Number of epilogue stages in EVT // Number of epilogue stages in EVT
static constexpr int EVTEpilogueStages = 1; static constexpr int EVTEpilogueStages = 1;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<CtaShape, WarpShape, ElementC, using OutputTileThreadMap = cutlass::epilogue::threadblock::
AlignmentC, EVTEpilogueStages>; OutputTileThreadLayout<CtaShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>;
// Definition of EVT // Definition of EVT
using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch;
using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; cutlass::multiplies,
using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, ElementComputeEpilogue,
Stride<_0, _1, _0>>; ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using bScaleSrc = cutlass::epilogue::threadblock::
VisitorRowBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_0, _1, _0>>;
using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>; using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeBScale, accSrc, bScaleSrc>;
using ComputeAScale = using ComputeAScale = cutlass::epilogue::threadblock::
cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, VisitorCompute<cutlass::multiplies, ElementC, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>;
cutlass::FloatRoundStyle::round_to_nearest>; using aScaleSrc = cutlass::epilogue::threadblock::
using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, VisitorColBroadcast<OutputTileThreadMap, ElementComputeEpilogue, Stride<_1, _0, _0>>;
Stride<_1, _0, _0>>;
using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>; using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT<ComputeAScale, EpilogueBScale, aScaleSrc>;
// With bias // With bias
using biasSrc = using biasSrc =
cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>; cutlass::epilogue::threadblock::VisitorRowBroadcast<OutputTileThreadMap, ElementOutput, Stride<_0, _1, _0>>;
using ComputeAScaleWithBias = using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::epilogue::threadblock::VisitorCompute<cutlass::multiply_add, ElementC, ElementComputeEpilogue, cutlass::multiply_add,
cutlass::FloatRoundStyle::round_to_nearest>; ElementC,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueAScaleWithBias = using EpilogueAScaleWithBias =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>; cutlass::epilogue::threadblock::Sm80EVT<ComputeAScaleWithBias, EpilogueBScale, aScaleSrc, biasSrc>;
using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< using dTar = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride<int64_t, _1, _0>>; OutputTileThreadMap,
using EpilogueStore = ElementC,
typename cutlass::platform::conditional<WithBias, cutlass::FloatRoundStyle::round_to_nearest,
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>, Stride<int64_t, _1, _0>>;
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type; using EpilogueStore = typename cutlass::platform::conditional<
WithBias,
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScaleWithBias>,
cutlass::epilogue::threadblock::Sm80EVT<dTar, EpilogueAScale>>::type;
using EpilogueOp = EpilogueStore; using EpilogueOp = EpilogueStore;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, ElementA,
cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, LayoutA,
ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, cutlass::ComplexTransform::kNone,
ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; AlignmentA,
ElementB,
LayoutB,
cutlass::ComplexTransform::kNone,
AlignmentB,
ElementC,
LayoutC,
AlignmentC,
ElementAccumulator,
ElementComputeEpilogue,
OperatorClass,
ArchTag,
CtaShape,
WarpShape,
InstructionShape,
EpilogueOp,
ThreadblockSwizzle,
Stages,
FP8MathOperator,
EVTEpilogueStages>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}; };
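As a readability aid, the EVT chain defined above reduces to the following scalar reference; this sketch is illustrative only and is not part of the kernel or of this diff.

// Scalar semantics of the SM89 rowwise-scaled epilogue: the accumulator is scaled by
// scale_b[col] (EpilogueBScale), then by scale_a[row] (EpilogueAScale), and the WithBias
// variant additionally adds bias[col] via multiply_add.
inline float sm89_epilogue_reference(
    float acc, float scale_a_row, float scale_b_col, float bias_col, bool with_bias) {
  float scaled = acc * scale_b_col;
  return with_bias ? scaled * scale_a_row + bias_col : scaled * scale_a_row;
}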
template <typename Gemm, bool WithBias> template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, typename Gemm::Arguments prepare_sm89_fp8_args(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementT = typename Gemm::ElementA; using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD; using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float; using ElementComputeEpilogue = float;
...@@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: ...@@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr()); ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr()); ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());
typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode typename Gemm::Arguments args(
{m, n, k}, // Problem size cutlass::gemm::GemmUniversalMode::kGemm, // Mode
1, // Split-k factor {m, n, k}, // Problem size
{}, // Epilogue args 1, // Split-k factor
ptr_a, // a pointer {}, // Epilogue args
ptr_b, // b pointer ptr_a, // a pointer
nullptr, // c pointer (unused) ptr_b, // b pointer
nullptr, // d pointer (unused) nullptr, // c pointer (unused)
m * k, // batch stride a (unused) nullptr, // d pointer (unused)
n * k, // batch stride b (unused) m * k, // batch stride a (unused)
m * n, // batch stride c (unused) n * k, // batch stride b (unused)
m * n, // batch stride d (unused) m * n, // batch stride c (unused)
lda, // stride a m * n, // batch stride d (unused)
ldb, // stride b lda, // stride a
ldc, // stride c (unused) ldb, // stride b
ldc); // stride d (unused) ldc, // stride c (unused)
ldc); // stride d (unused)
if constexpr (WithBias) { if constexpr (WithBias) {
args.epilogue = {{ args.epilogue = {
{ {
{}, // Accumulator {
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, {}, // Accumulator
{} // Multiplies {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
}, {} // Multiplies
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, },
{ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
{} // Multiplies {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}},
}, {} // Multiplies
{ptr_d, {n, _1{}, _0{}}}}; },
{ptr_d, {n, _1{}, _0{}}}};
} else { } else {
args.epilogue = {{ args.epilogue = {
{ {
{}, // Accumulator {
{ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, {}, // Accumulator
{} // Multiplies {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}},
}, {} // Multiplies
{ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, },
{} // Multiplies {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}},
}, {} // Multiplies
{ptr_d, {n, _1{}, _0{}}}}; },
{ptr_d, {n, _1{}, _0{}}}};
} }
return args; return args;
} }
template <typename Gemm, bool WithBias> template <typename Gemm, bool WithBias>
void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, void launch_sm89_fp8_scaled_mm(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias); auto args = prepare_sm89_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op; Gemm gemm_op;
...@@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const ...@@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
} }
template <typename OutType, typename CtaShape, typename WarpShape, int Stages> template <typename OutType, typename CtaShape, typename WarpShape, int Stages>
void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, void sm89_fp8_dispatch_bias(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementInput = cutlass::float_e4m3_t; using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType; using ElementOutput = OutType;
using AccumElementType = float; using AccumElementType = float;
if (bias) { if (bias) {
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape, using Gemm = typename DeviceGemmFp8RowwiseSm89<
Stages, true>::Gemm; ElementInput,
ElementOutput,
AccumElementType,
CtaShape,
WarpShape,
Stages,
true>::Gemm;
return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias); return launch_sm89_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else { } else {
using Gemm = typename DeviceGemmFp8RowwiseSm89<ElementInput, ElementOutput, AccumElementType, CtaShape, WarpShape, using Gemm = typename DeviceGemmFp8RowwiseSm89<
Stages, false>::Gemm; ElementInput,
ElementOutput,
AccumElementType,
CtaShape,
WarpShape,
Stages,
false>::Gemm;
return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias); return launch_sm89_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
} }
} }
template <typename OutType> template <typename OutType>
void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, void sm89_fp8_dispatch_shape(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
uint32_t const n = out.size(1); uint32_t const n = out.size(1);
if (m == 1) { if (m == 1) {
if (n <= 8192) { if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} }
} else if (m <= 16) { } else if (m <= 16) {
// M in (1, 16] // M in (1, 16]
if (n <= 8192) { if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
4>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) { } else if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} }
} else if (m <= 64) { } else if (m <= 64) {
// M in (16, 64] // M in (16, 64]
if (n <= 16384) { if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} }
} else if (m <= 128) { } else if (m <= 128) {
// M in (64, 128] // M in (64, 128]
if (n <= 8192) { if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
4>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) { } else if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} }
} else if (m <= 256) { } else if (m <= 256) {
// M in (128, 256] // M in (128, 256]
if (n <= 8192) { if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} else if (n <= 16384) { } else if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<128, 64, 128>,
cutlass::gemm::GemmShape<64, 32, 128>,
4>(out, a, b, scales_a, scales_b, bias);
} }
} else if (m <= 512) { } else if (m <= 512) {
// M in (256, 512] // M in (256, 512]
if (n <= 16384) { if (n <= 16384) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
2>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
4>(out, a, b, scales_a, scales_b, bias);
} }
} else { } else {
// M in (512, inf) // M in (512, inf)
if (n <= 8192) { if (n <= 8192) {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
3>(out, a, b, scales_a, scales_b, bias);
} else { } else {
return sm89_fp8_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, return sm89_fp8_dispatch_bias<
cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); OutType,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>,
2>(out, a, b, scales_a, scales_b, bias);
} }
} }
} }
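To make the dispatch table above concrete, here is the configuration one hypothetical problem size resolves to; the trace is derived directly from the branches above and is illustrative only.

// Example: m = 96, n = 4096 falls into the "M in (64, 128]" branch with n <= 8192,
// so the call lowers to:
//   sm89_fp8_dispatch_bias<OutType,
//                          cutlass::gemm::GemmShape<64, 64, 128>,   // CtaShape
//                          cutlass::gemm::GemmShape<32, 64, 64>,    // WarpShape
//                          4>(out, a, b, scales_a, scales_b, bias); // Stages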
#endif #endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape, template <
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType, typename ElementType,
typename TileSchedulerType = void, bool WithBias = false> typename OutElementType,
typename AccumElementType,
typename CTAShape,
typename ClusterShape,
typename MainloopScheduleType,
typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm90 { struct DeviceGemmFp8RowwiseSm90 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)"); static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
...@@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 { ...@@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 {
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default
// setting in the Collective Builder // setting in the Collective Builder
// Implement rowwise scaling epilogue. // Implement rowwise scaling epilogue.
using XScale = using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, 0,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>; TileShape,
ElementComputeEpilogue,
using WScale = ElementComputeEpilogue,
cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, 0,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>; TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementOutput,
ElementOutput,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
ElementComputeEpilogue, // First stage output type. cutlass::multiplies,
ElementComputeEpilogue, // First stage input types. ElementComputeEpilogue, // First stage output type.
cutlass::FloatRoundStyle::round_to_nearest>; ElementComputeEpilogue, // First stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementOutput, using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
ElementComputeEpilogue, // Second stage input types. cutlass::multiplies,
cutlass::FloatRoundStyle::round_to_nearest>; ElementOutput,
ElementComputeEpilogue, // Second stage input types.
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>; using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
// With bias // With bias
using ComputeWithBias = using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementOutput, ElementComputeEpilogue, cutlass::multiply_add,
cutlass::FloatRoundStyle::round_to_nearest>; ElementOutput,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>; using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type; using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::arch::Sm90,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementComputeEpilogue,
ElementC,
LayoutC,
AlignmentC,
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecialized,
EpilogueEVT>::CollectiveOp; EpilogueEVT>::CollectiveOp;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
...@@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 { ...@@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 {
using FastAccum = FastPongSchedule; // Defaults to the Pingpong schedule using FastAccum = FastPongSchedule; // Defaults to the Pingpong schedule
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, ArchTag,
TileShape, ClusterShape, OperatorClass,
ElementA,
LayoutA,
AlignmentA,
ElementB,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>, sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp; MainloopScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
}; };
template <typename Gemm, bool WithBias> template <typename Gemm, bool WithBias>
typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, typename Gemm::Arguments prepare_sm90_fp8_args(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementT = typename Gemm::ElementA; using ElementT = typename Gemm::ElementA;
using ElementOutput = typename Gemm::ElementD; using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float; using ElementComputeEpilogue = float;
...@@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: ...@@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
StrideC stride_c; StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, typename Gemm::Arguments args = {
{m, n, k, 1}, cutlass::gemm::GemmUniversalMode::kGemm,
{ptr_a, stride_a, ptr_b, stride_b}, {m, n, k, 1},
{{}, // epilogue.thread {ptr_a, stride_a, ptr_b, stride_b},
nullptr, {{}, // epilogue.thread
stride_c, nullptr,
ptr_d, stride_c,
stride_d}}; ptr_d,
stride_d}};
if constexpr (WithBias) { if constexpr (WithBias) {
args.epilogue.thread = { args.epilogue.thread = {
{ptr_scales_a}, {ptr_scales_a},
...@@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: ...@@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
} }
template <typename Gemm, bool WithBias> template <typename Gemm, bool WithBias>
void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, void launch_sm90_fp8_scaled_mm(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias); auto args = prepare_sm90_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
Gemm gemm_op; Gemm gemm_op;
...@@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const ...@@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
TORCH_CHECK(status == cutlass::Status::kSuccess) TORCH_CHECK(status == cutlass::Status::kSuccess)
} }
template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType, template <
typename TileSchedulerType> typename OutType,
void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, typename CTAShape,
const torch::Tensor& scales_a, const torch::Tensor& scales_b, typename ClusterShape,
const c10::optional<torch::Tensor>& bias, bool fast_accum = true, typename MainloopScheduleType,
bool use_persistent = false) { typename TileSchedulerType>
void sm90_fp8_dispatch_bias(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias,
bool fast_accum = true,
bool use_persistent = false) {
using ElementInput = cutlass::float_e4m3_t; using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType; using ElementOutput = OutType;
using AccumElementType = float; using AccumElementType = float;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
if (bias) { if (bias) {
using Gemm = using Gemm = typename DeviceGemmFp8RowwiseSm90<
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape, ElementInput,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm; ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias); return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else { } else {
using Gemm = using Gemm = typename DeviceGemmFp8RowwiseSm90<
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape, ElementInput,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm; ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias); return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
} }
} }
template <typename OutType> template <typename OutType>
void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, void sm90_fp8_dispatch_shape(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0); uint32_t const m = a.size(0);
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
using BasicTileScheduler = void; using BasicTileScheduler = void;
if (m <= 1) { if (m <= 1) {
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler, return sm90_fp8_dispatch_bias<
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); OutType,
Shape<_64, _64, _128>,
Shape<_1, _8, _1>,
FastBasicScheduler,
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
} }
if (m <= 64) { if (m <= 64) {
// m in [1, 64] // m in [1, 64]
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler, return sm90_fp8_dispatch_bias<
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); OutType,
Shape<_64, _64, _128>,
Shape<_1, _4, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 256) { } else if (m <= 256) {
// m in (64, 256] // m in (64, 256]
return sm90_fp8_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler, return sm90_fp8_dispatch_bias<
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); OutType,
Shape<_64, _64, _128>,
Shape<_1, _1, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 1024) { } else if (m <= 1024) {
// m in (256, 1024] // m in (256, 1024]
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler, return sm90_fp8_dispatch_bias<
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); OutType,
Shape<_128, _128, _128>,
Shape<_1, _1, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else { } else {
// m in (1024, inf) // m in (1024, inf)
return sm90_fp8_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler, return sm90_fp8_dispatch_bias<
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); OutType,
Shape<_128, _128, _128>,
Shape<_2, _1, _1>,
FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} }
} }
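A matching trace for the SM90 dispatch above, again illustrative only and derived from the branches as written.

// Example: m = 512 lands in the "(256, 1024]" branch, so the call lowers to:
//   sm90_fp8_dispatch_bias<OutType,
//                          Shape<_128, _128, _128>,  // CTAShape
//                          Shape<_1, _1, _1>,        // ClusterShape
//                          FastPingpongScheduler,
//                          PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);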
#endif #endif
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, torch::Tensor fp8_scaled_mm(
const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const torch::Tensor& mat_a,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
...@@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat ...@@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, TORCH_CHECK(
"mat_a must be multiple of 16 bytes for memory alignment"); (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, TORCH_CHECK(
"mat_b must be multiple of 16 bytes for memory alignment"); (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
......
...@@ -35,11 +35,20 @@ limitations under the License. ...@@ -35,11 +35,20 @@ limitations under the License.
using namespace cute; using namespace cute;
template <typename ElementOutput, typename ArchTag, typename ThreadblockShape, typename WarpShape, template <
typename InstructionShape, int NumStages> typename ElementOutput,
void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, typename ArchTag,
const torch::Tensor& scales_a, const torch::Tensor& scales_b, typename ThreadblockShape,
const c10::optional<torch::Tensor>& bias) { typename WarpShape,
typename InstructionShape,
int NumStages>
void cutlass_int8_scaled_mm(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ElementAccumulator = int32_t; using ElementAccumulator = int32_t;
using ElementCompute = float; using ElementCompute = float;
using ElementInputA = int8_t; using ElementInputA = int8_t;
...@@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons ...@@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, using DefaultGemmConf = cutlass::gemm::device::
ElementInputB, ElementOutput, ElementCompute>; DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp; using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB, ElementInputA,
cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor, cutlass::layout::RowMajor,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, DefaultGemmConf::kAlignmentA,
ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel; ElementInputB,
cutlass::layout::ColumnMajor,
DefaultGemmConf::kAlignmentB,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
NumStages,
true,
typename DefaultGemmConf::Operator>::GemmKernel;
using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits<ElementOutput>::value>, GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
cutlass::sizeof_bits<ElementOutput>::value>,
ElementCompute>; ElementCompute>;
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol< using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator, ThreadblockShape,
typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>; GemmKernel_::kThreadCount,
AlphaColTileIterator,
typename GemmKernel_::Epilogue::OutputTileIterator,
ElementAccumulator,
ElementCompute,
EpilogueOutputOp>;
using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< using Epilogue = typename cutlass::epilogue::threadblock::
EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue; EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;
using GemmKernel = using GemmKernel =
cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>; cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;
...@@ -104,98 +134,164 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons ...@@ -104,98 +134,164 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
typename EpilogueOutputOp::Params linearScalingParams; typename EpilogueOutputOp::Params linearScalingParams;
typename EpilogueVisitor::Arguments visitor_args{linearScalingParams}; typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};
typename Gemm::Arguments args{{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, typename Gemm::Arguments args{
{a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args}; {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};
auto workspace = torch::empty(gemm_op.get_workspace_size(args), auto workspace = torch::empty(
torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
auto can_implement = gemm_op.can_implement(args); auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess, TORCH_CHECK(
"gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); can_implement == cutlass::Status::kSuccess,
"gemm cannot implement, error: ",
cutlassGetStatusString(can_implement));
auto status = gemm_op(args, workspace.data_ptr(), stream); auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status)); TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
} }
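A hedged example of instantiating the int8 wrapper above directly; the output element type and instruction shape shown here are assumptions for illustration, since the real values are supplied by the sm75/sm80 dispatchers below.

// Illustration only, with assumed template arguments:
//   cutlass_int8_scaled_mm<cutlass::bfloat16_t,                     // ElementOutput (assumed)
//                          cutlass::arch::Sm80,                     // ArchTag
//                          cutlass::gemm::GemmShape<128, 128, 64>,  // ThreadblockShape
//                          cutlass::gemm::GemmShape<64, 64, 64>,    // WarpShape
//                          cutlass::gemm::GemmShape<16, 8, 32>,     // InstructionShape (assumed)
//                          5>(out, mat_a, mat_b, scales_a, scales_b, bias);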
template <typename ElementOutput, typename ArchTag, typename InstructionShape> template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, void sm75_dispatch_shape(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
int m = mat_a.size(0); int m = mat_a.size(0);
if (m <= 32) { if (m <= 32) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 128, 64>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else if (m <= 64) { } else if (m <= 64) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else if (m <= 256) { } else if (m <= 256) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<128, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else { } else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
2>(out, mat_a, mat_b, scales_a, scales_b, bias);
} }
} }
template <typename ElementOutput, typename ArchTag, typename InstructionShape> template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, void sm80_dispatch_shape(
const torch::Tensor& scales_a, const torch::Tensor& scales_b, torch::Tensor& out,
const c10::optional<torch::Tensor>& bias) { const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
int m = mat_a.size(0); int m = mat_a.size(0);
int n = mat_b.size(1); int n = mat_b.size(1);
if (m <= 16) { if (m <= 16) {
if (n <= 4096) { if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
6>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else { } else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<16, 64, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} }
} else if (m <= 32) { } else if (m <= 32) {
if (n <= 4096) { if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
6>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else { } else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<32, 64, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} }
} else if (m <= 64) { } else if (m <= 64) {
if (n <= 4096) { if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 64, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else { } else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} }
} else if (m <= 128 && n < 8192) { } else if (m <= 128 && n < 8192) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else { } else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<128, 128, 64>, cutlass_int8_scaled_mm<
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, ElementOutput,
scales_b, bias); ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
} }
} }
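One concrete trace of the SM80 int8 dispatch above, illustrative only and read straight off the branches.

// Example: m = 32, n = 4096 takes the m <= 32, n <= 4096 branch, lowering to:
//   cutlass_int8_scaled_mm<ElementOutput, ArchTag,
//                          cutlass::gemm::GemmShape<32, 64, 128>,  // ThreadblockShape
//                          cutlass::gemm::GemmShape<32, 64, 64>,   // WarpShape
//                          InstructionShape, 6>(out, mat_a, mat_b, scales_a, scales_b, bias);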
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType, template <
bool WithBias> typename ElementOutput,
void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, typename TileShape,
const torch::Tensor& scales_a, const torch::Tensor& scales_b, typename ClusterShape,
const c10::optional<torch::Tensor>& bias) { typename MainloopScheduleType,
bool WithBias>
void cutlass_int8_scaled_mm_sm90(
torch::Tensor& out,
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using ArchTag = cutlass::arch::Sm90; using ArchTag = cutlass::arch::Sm90;
using ElementAccumulator = int32_t; using ElementAccumulator = int32_t;
...@@ -213,50 +309,75 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, ...@@ -213,50 +309,75 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
using TileSchedulerType = cutlass::gemm::PersistentScheduler; using TileSchedulerType = cutlass::gemm::PersistentScheduler;
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, using XScale = cutlass::epilogue::fusion::
Stride<Int<1>, Int<0>, Int<0>>>; Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;
  using Bias = cutlass::epilogue::fusion::
      Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  // Scale
  using Compute0 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
  using Compute1 = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

  // With bias
  using ComputeWithBias = cutlass::epilogue::fusion::
      Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

  using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentC,
      ElementOutput,
      cutlass::layout::RowMajor,
      AlignmentOutput,
      EpilogueScheduleType,
      EpilogueEVT>::CollectiveOp;

  using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
      sizeof(typename CollectiveEpilogue::SharedStorage))>;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementInputA,
      cutlass::layout::RowMajor,
      AlignmentA,
      ElementInputB,
      cutlass::layout::ColumnMajor,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      Stages,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
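  // In effect (reading the fusion types above): the epilogue computes out = XScale * (WScale * accum),
  // and ComputeWithBias additionally adds Bias via multiply_add when WithBias is true;
  // EpilogueEVT selects between the two variants at compile time.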
@@ -283,14 +404,15 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      {a_ptr, stride_a, b_ptr, stride_b},
      {{},  // epilogue.thread
       nullptr,
       stride_c,
       o_ptr,
       stride_d}};

  if constexpr (WithBias) {
    ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
@@ -308,23 +430,29 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a,
    };
  }

  auto workspace = torch::empty(
      gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));

  auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "gemm cannot implement, error: ",
      cutlassGetStatusString(can_implement));

  auto status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  if (bias) {
    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
        out, mat_a, mat_b, scales_a, scales_b, bias);
@@ -335,45 +463,73 @@ void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const to
}

template <typename ElementOutput>
void sm90_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  int m = mat_a.size(0);
  int n = mat_b.size(1);
  if (m <= 32) {
    if (n < 8192) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_1, _8, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _128, _128>,
          Shape<_1, _8, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 64) {
    if (n < 8192) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_1, _4, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _256>,
          Shape<_1, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else if (m <= 128) {
    if (n <= 4096) {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _64, _128>,
          Shape<_2, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      return sm90_dispatch_bias<
          ElementOutput,
          Shape<_64, _128, _128>,
          Shape<_2, _1, _1>,
          cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
  } else {
    return sm90_dispatch_bias<
        ElementOutput,
        Shape<_128, _128, _128>,
        Shape<_2, _1, _1>,
        cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias);
  }
}
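// Dispatch illustration (assumed values, not from the source): m = 48, n = 4096 falls into the
// "m <= 64, n < 8192" branch above, so the GEMM runs with TileShape Shape<_64, _64, _128>,
// ClusterShape Shape<_1, _4, _1>, and the KernelTmaWarpSpecialized mainloop schedule.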
torch::Tensor int8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype,
    const c10::optional<torch::Tensor>& bias) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -8,8 +8,8 @@
#include "utils.h"

template <typename T>
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
  float max_value = 0.0f;
  unsigned int tid = threadIdx.x;
  unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}

template <typename T>
__global__ void per_tensor_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output,
    const float* __restrict__ scale,
    const int64_t num_elements) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int grid_size = blockDim.x * gridDim.x;
  const float scale_val = 1.0f / (*scale);
@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
    }

    per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        num_elements);
    return true;
  });
}
@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}

template <typename T>
__global__ void per_token_group_quant_fp8_kernel(
    const T* __restrict__ input,
    void* __restrict__ output_q,
    float* __restrict__ output_s,
    const int group_size,
    const int num_groups,
    const float eps,
    const float fp8_min,
    const float fp8_max) {
  const int groups_per_block = 16;
  const int local_group_id = threadIdx.x / 16;
  const int lane_id = threadIdx.x % 16;
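  // Work partition implied by the three constants above: each block quantizes 16 groups
  // (groups_per_block), with 16 threads cooperating per group; local_group_id selects the
  // group within the block and lane_id is the thread's position inside that group.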
@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
  }
}

void sgl_per_token_group_quant_fp8(
    torch::Tensor input,
    torch::Tensor output_q,
    torch::Tensor output_s,
    int64_t group_size,
    double eps,
    double fp8_min,
    double fp8_max) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);
@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        output_q.data_ptr(),
        static_cast<float*>(output_s.data_ptr()),
        group_size,
        num_groups,
        (float)eps,
        (float)fp8_min,
        (float)fp8_max);
    return true;
  });
}
@@ -7,9 +7,12 @@
#include "utils.h"

template <typename T>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;
@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(input.data_ptr()),
        static_cast<FP8_TYPE*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
    return true;
  });
}
@@ -25,9 +25,11 @@ limitations under the License.
#define WARP_SIZE 32

template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ cumsum_buffer,
    size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel,
    int32_t* __restrict__ cumsum) {
  __shared__ int32_t shared_counts[WARP_SIZE][8];

  const int warp_id = threadIdx.x / WARP_SIZE;
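  // Capacity note (derived from the declaration above): shared_counts provides
  // WARP_SIZE * 8 = 256 per-expert counters, matching the num_experts == 256 restriction
  // checked in moe_align_block_size below.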
@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
  }
}

void moe_align_block_size(
    torch::Tensor topk_ids,
    int64_t num_experts,
    int64_t block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad,
    torch::Tensor token_cnts_buffer,
    torch::Tensor cumsum_buffer) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only supports DeepSeek V3 now.");

  DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
    align_kernel<<<1, 1024, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        experts_ids.data_ptr<int32_t>(),
        num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts,
        block_size,
        topk_ids.numel(),
        cumsum_buffer.data_ptr<int32_t>());

    const int block_threads = 256;
    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
    const int actual_blocks = std::min(num_blocks, max_blocks);

    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
        topk_ids.data_ptr<scalar_t>(),
        sorted_token_ids.data_ptr<int32_t>(),
        cumsum_buffer.data_ptr<int32_t>(),
        topk_ids.numel());
  });
}
@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
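// Size illustration (assumed values, not from the source): with bs = 2, draft_token = 8 and
// verified_seq_len = [10, 12], tree_mask holds 22*8 + 2*8*8 = 304 entries and positions holds
// bs * draft_token = 16 entries.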
__global__ void build_tree_efficient(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
  }
}

void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree_efficient<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}

// parent_list [bs, topk * (depth - 1) + 1)]
@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
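// Unlike build_tree_efficient above, retrive_index here is 3-D ([bs, draft_token, depth + 2]);
// the efficient variant keeps it as [bs, draft_token] and encodes the tree structure in the
// separate retrive_next_token / retrive_next_sibling tables.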
__global__ void build_tree(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;
@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
  }
}

void build_tree_kernel(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}