Unverified Commit d052f4c8 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

New clang format for sgl kernel (#4194)

parent e1aaa79a
cp ../README.md ../LICENSE .
rm -rf dist
python3 -m build
python3 -m twine upload dist/*
rm -rf README.md LICENSE
......@@ -6,3 +6,10 @@ DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
AllowShortLoopsOnASingleLine: false
BinPackParameters: false # Prevents packing parameters in declarations
BinPackArguments: false # Prevents packing arguments in function calls
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
AlignOperands: Align # Aligns arguments vertically
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
......@@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
// support float16, bfloat16 and float32
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()),
batch_size,
hidden_size,
eps,
torch_current_stream);
TORCH_CHECK(
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
}
......@@ -153,19 +153,20 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg,
DINLINE void start_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal* self_sg,
int rank) {
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED,
__MEMORY_SCOPE_SYSTEM);
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
flag)
......@@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg,
DINLINE void end_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal* self_sg,
int rank) {
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
......@@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg,
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM);
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag)
while (__scoped_atomic_load_n(
&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__MEMORY_SCOPE_DEVICE) < flag)
;
}
__syncthreads();
......@@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal* self_sg,
T* __restrict__ result, int rank, int size) {
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
......@@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) {
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
volatile
#endif
Signal* self_sg,
T* __restrict__ result, int rank, int size) {
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
......@@ -357,8 +372,14 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets, int rank, bool full_nvlink = true)
CustomAllreduce(
Signal* meta,
void* rank_data,
size_t rank_data_sz,
const hipIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets,
int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
......@@ -382,8 +403,8 @@ class CustomAllreduce {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle),
hipIpcMemLazyEnablePeerAccess));
CUDACHECK(hipIpcOpenMemHandle(
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
......@@ -399,13 +420,14 @@ class CustomAllreduce {
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (hipPointerGetAttribute(&base_ptr,
if (hipPointerGetAttribute(
&base_ptr,
#ifdef USE_ROCM
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(hipDeviceptr_t)ptr) != hipSuccess)
(hipDeviceptr_t)ptr) != hipSuccess)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
......@@ -415,8 +437,8 @@ class CustomAllreduce {
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error("Rank data buffer is overflowed by " +
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
......@@ -443,8 +465,8 @@ class CustomAllreduce {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
......@@ -474,11 +496,17 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(hipStream_t stream, T* input, T* output, int size,
void allreduce(
hipStream_t stream,
T* input,
T* output,
int size,
#ifndef USE_ROCM
int threads = 512, int block_limit = 36){
int threads = 512,
int block_limit = 36){
#else
int threads = 512, int block_limit = 16) {
int threads = 512,
int block_limit = 16) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
......@@ -487,8 +515,8 @@ class CustomAllreduce {
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
hipStreamCaptureStatus status;
......@@ -499,17 +527,17 @@ class CustomAllreduce {
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
hipLaunchKernelGGL((name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \
size);
#define KL(ngpus, name) \
hipLaunchKernelGGL( \
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
......
......@@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) {
return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx) {
__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx) {
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, world_size]
......@@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
}
template <bool start, bool need_fence = false>
__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx, int const grid_size) {
__inline__ __device__ void block_barrier(
uint32_t** signals,
uint32_t const flag,
size_t const local_rank,
size_t const world_size,
int const tidx,
int const bidx,
int const grid_size) {
if constexpr (!start) {
__syncthreads();
}
......@@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size);
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
......@@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
}
block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size);
block_barrier<true>(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
......@@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
}
}
block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
bidx, grid_size);
block_barrier<false, true>(
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
......@@ -459,8 +470,12 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
cudaStream_t stream) {
void dispatchARKernels(
AllReduceStrategyType algo,
AllReduceParams& param,
int blocks_per_grid,
int threads_per_block,
cudaStream_t stream) {
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
......@@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
CHECK_CUDA_SUCCESS(cudaGetLastError());
}
void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
cudaStream_t stream) {
void trtCustomAllReduce(
AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) {
if (params.elts_total == 0) {
return;
}
......
......@@ -29,9 +29,14 @@ using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
class AllReduceMeta {
public:
AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
AllReduceMeta(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
this->rank_id = (int)rank_id;
this->world_size = (int)world_size;
this->barrier_in = barrier_in;
......@@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype)
return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
fptr_t init_custom_ar(
int64_t rank_id,
int64_t world_size,
torch::Tensor& rank_data,
const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers,
const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out) {
auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out);
return (fptr_t)m;
}
......@@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
......@@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) {
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
std::vector<std::string> handle_bytes;
handle_bytes.reserve(handles.size());
......
......@@ -23,15 +23,18 @@ limitations under the License.
#define THREADS_PER_BLOCK 128
template <typename T>
__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d]
const T* __restrict__ k, // [b, h, 1, d]
const T* __restrict__ v, // [b, h, 1, e]
const float* __restrict__ past_kv, // [b, h, d, e]
const float* __restrict__ slope, // [h, 1, 1]
T* __restrict__ output, // [b, h, 1, e]
float* __restrict__ new_kv, // [b, h, d, e]
const int batch_size, const int num_heads, const int qk_dim,
const int v_dim) {
__global__ void lightning_attention_decode_kernel(
const T* __restrict__ q, // [b, h, 1, d]
const T* __restrict__ k, // [b, h, 1, d]
const T* __restrict__ v, // [b, h, 1, e]
const float* __restrict__ past_kv, // [b, h, d, e]
const float* __restrict__ slope, // [h, 1, 1]
T* __restrict__ output, // [b, h, 1, e]
float* __restrict__ new_kv, // [b, h, d, e]
const int batch_size,
const int num_heads,
const int qk_dim,
const int v_dim) {
extern __shared__ char smem[];
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
......@@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
}
}
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
torch::Tensor new_kv) {
void lightning_attention_decode(
const torch::Tensor& q,
const torch::Tensor& k,
const torch::Tensor& v,
const torch::Tensor& past_kv,
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv) {
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
......@@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
qk_dim, v_dim);
q.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(),
v.data_ptr<scalar_t>(),
past_kv.data_ptr<float>(),
slope.data_ptr<float>(),
output.data_ptr<scalar_t>(),
new_kv.data_ptr<float>(),
batch_size,
num_heads,
qk_dim,
v_dim);
}));
}
......@@ -25,9 +25,15 @@ namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_,
bool UseMasking_ = false>
template <
typename ThreadblockShape_,
int ThreadCount,
typename ScaleTileIterator_,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementwiseFunctor_,
bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
......@@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol {
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
int64_t batch_stride_D_)
Arguments(
typename ElementwiseFunctor::Params elementwise_,
int64_t batch_stride_alpha_,
int64_t batch_stride_C_,
int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
......@@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol {
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant,
bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
EpilogueVisitorPerRowPerCol(
Params const& params,
SharedStorage& shared_storage,
cutlass::MatrixCoord const& problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
bool with_bias,
bool per_token_quant,
bool per_channel_quant,
AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col,
typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
......@@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol {
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
......@@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol {
private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
......@@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol {
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
......
......@@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
// specialized dynamic schedule For FP8 kernels with Block Scaling
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>, class KernelSchedule = KernelTmaWarpSpecialized,
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
template <
int Stages_,
class ClusterShape_ = Shape<_1, _1, _1>,
class KernelSchedule = KernelTmaWarpSpecialized,
int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M,
// while zero-value `ScaleGranularityM` indicates that scaling
// granularity is `size<0>(TileShape_MNK{})` along M.
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(cute::is_same_v<KernelSchedule,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
"KernelSchedule must be one of the warp specialized policies");
static_assert(
cute::
is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
"KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
......
......@@ -159,8 +159,9 @@ class GemmUniversalBaseCompat {
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
CUTLASS_TRACE_HOST(
" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
......@@ -175,8 +176,8 @@ class GemmUniversalBaseCompat {
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
GemmKernel::kThreadCount, smem_size);
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
......@@ -184,12 +185,12 @@ class GemmUniversalBaseCompat {
}
} else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
GemmKernel::kThreadCount, 0);
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
<< cudaGetErrorString(result));
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
......@@ -226,8 +227,9 @@ class GemmUniversalBaseCompat {
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
CUTLASS_TRACE_HOST(
"GemmUniversalBaseCompat::initialize() - workspace " << workspace
<< ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
......
......@@ -32,10 +32,11 @@ namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
......@@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor {
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
/// constructs an arguments structure
Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_,
typename EpilogueVisitor::Arguments epilogue_visitor_)
Arguments(
GemmCoord problem_size_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_,
TensorRefC ref_C_,
TensorRefC ref_D_,
typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(GemmUniversalMode::kGemm),
problem_size(problem_size_),
batch_count(1),
......@@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
} else if (
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
......@@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
} else if (
platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value ||
platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
......@@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
} else if (
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value ||
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
......@@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor {
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx,
tb_offset_A);
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx,
tb_offset_B);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
......@@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor {
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN);
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
......@@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor {
with_bias = false;
}
EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(),
thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col,
params.ptr_C, params.ptr_D, threadblock_offset,
blockIdx.y * params.problem_size.m());
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params.params_alpha_col,
params.params_C,
params.params_D,
with_bias,
true,
true,
params.ptr_alpha_row,
params.ptr_alpha_col,
params.ptr_C,
params.ptr_D,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
......
......@@ -21,10 +21,13 @@
#include "utils.h"
static void check_group_count(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs) {
TORCH_CHECK(((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
"The group count of inputs, weights and outputs should be the same.");
static void check_group_count(
const std::vector<torch::Tensor>& inputs,
const std::vector<torch::Tensor>& weights,
const std::vector<torch::Tensor>& outputs) {
TORCH_CHECK(
((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
"The group count of inputs, weights and outputs should be the same.");
}
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
......@@ -68,21 +71,26 @@ static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tens
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
TORCH_CHECK(cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice,
stream) == CUBLAS_STATUS_SUCCESS);
TORCH_CHECK(
cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
CUBLAS_STATUS_SUCCESS);
return gpu_ptrs;
}
// We want compute input @ weight^T in row major
// This is equivalent to computing weight @ input^T in col major
// Cublas only accepts matrix in column major, so this arrangement is needed
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, // b: (m, k) row major = (k, m) col major
const std::vector<torch::Tensor>& weights, // a: (n, k) row major = (n, k)^T col major
const std::vector<torch::Tensor>& outputs, // c: (m, n) row major = (n, m) col major
const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream) {
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
"cublas grouped_gemm can"
"only be applied to float16 and bfloat16 dtype");
void cublas_grouped_gemm(
const std::vector<torch::Tensor>& inputs, // b: (m, k) row major = (k, m) col major
const std::vector<torch::Tensor>& weights, // a: (n, k) row major = (n, k)^T col major
const std::vector<torch::Tensor>& outputs, // c: (m, n) row major = (n, m) col major
const torch::Dtype& out_dtype,
int64_t cublas_handle,
int64_t cuda_stream) {
TORCH_CHECK(
out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
"cublas grouped_gemm can"
"only be applied to float16 and bfloat16 dtype");
int group_count = inputs.size();
check_group_count(inputs, weights, outputs);
......@@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, // b: (m, k
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
auto status = cublasGemmGroupedBatchedEx(handle, transa_array.data(), transb_array.data(), m_array.data(),
n_array.data(), k_array.data(), alpha_array.data(), (void**)d_a.data_ptr(),
cuda_data_type, lda_array.data(), (void**)d_b.data_ptr(), cuda_data_type,
ldb_array.data(), beta_array.data(), (void**)d_c.data_ptr(), cuda_data_type,
ldc_array.data(), group_count, group_size.data(), CUBLAS_COMPUTE_32F);
auto status = cublasGemmGroupedBatchedEx(
handle,
transa_array.data(),
transb_array.data(),
m_array.data(),
n_array.data(),
k_array.data(),
alpha_array.data(),
(void**)d_a.data_ptr(),
cuda_data_type,
lda_array.data(),
(void**)d_b.data_ptr(),
cuda_data_type,
ldb_array.data(),
beta_array.data(),
(void**)d_c.data_ptr(),
cuda_data_type,
ldc_array.data(),
group_count,
group_size.data(),
CUBLAS_COMPUTE_32F);
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
}
......@@ -35,8 +35,12 @@
using namespace cute;
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
void launch_sm90_fp8_blockwise_scaled_mm(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
......@@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC,
LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueSchedule,
StoreEpilogueCompute>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
TileShape, ClusterShape,
ArchTag,
OperatorClass,
ElementA,
LayoutA,
AlignmentA,
ElementB,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm_op;
......@@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor
}
template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
void sm90_fp8_blockwise_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b) {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Dtype& out_dtype) {
torch::Tensor fp8_blockwise_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const torch::Dtype& out_dtype) {
TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
......@@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0,
"mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0,
"mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(
(mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(
(mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment");
TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");
......@@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
......@@ -8,8 +8,8 @@
#include "utils.h"
template <typename T>
__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
const int64_t num_elements) {
__global__ void
per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) {
float max_value = 0.0f;
unsigned int tid = threadIdx.x;
unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r
}
template <typename T>
__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output,
const float* __restrict__ scale, const int64_t num_elements) {
__global__ void per_tensor_quant_fp8_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output,
const float* __restrict__ scale,
const int64_t num_elements) {
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int grid_size = blockDim.x * gridDim.x;
const float scale_val = 1.0f / (*scale);
......@@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
}
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), num_elements);
static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
num_elements);
return true;
});
}
......@@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) {
}
template <typename T>
__global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, void* __restrict__ output_q,
float* __restrict__ output_s, const int group_size,
const int num_groups, const float eps, const float fp8_min,
const float fp8_max) {
__global__ void per_token_group_quant_fp8_kernel(
const T* __restrict__ input,
void* __restrict__ output_q,
float* __restrict__ output_s,
const int group_size,
const int num_groups,
const float eps,
const float fp8_min,
const float fp8_max) {
const int groups_per_block = 16;
const int local_group_id = threadIdx.x / 16;
const int lane_id = threadIdx.x % 16;
......@@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
}
}
void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s,
int64_t group_size, double eps, double fp8_min, double fp8_max) {
void sgl_per_token_group_quant_fp8(
torch::Tensor input,
torch::Tensor output_q,
torch::Tensor output_s,
int64_t group_size,
double eps,
double fp8_min,
double fp8_max) {
CHECK_INPUT(input);
CHECK_INPUT(output_q);
CHECK_INPUT(output_s);
......@@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
per_token_group_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), output_q.data_ptr(), static_cast<float*>(output_s.data_ptr()),
group_size, num_groups, (float)eps, (float)fp8_min, (float)fp8_max);
static_cast<scalar_t*>(input.data_ptr()),
output_q.data_ptr(),
static_cast<float*>(output_s.data_ptr()),
group_size,
num_groups,
(float)eps,
(float)fp8_min,
(float)fp8_max);
return true;
});
}
......@@ -7,9 +7,12 @@
#include "utils.h"
template <typename T>
__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s, const int64_t hidden_dim,
const int64_t num_tokens) {
__global__ void per_token_quant_fp8_kernel(
const T* __restrict__ input,
FP8_TYPE* __restrict__ output_q,
float* __restrict__ output_s,
const int64_t hidden_dim,
const int64_t num_tokens) {
const int token_idx = blockIdx.x;
if (token_idx >= num_tokens) return;
......@@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
per_token_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()), hidden_dim, num_tokens);
static_cast<scalar_t*>(input.data_ptr()),
static_cast<FP8_TYPE*>(output_q.data_ptr()),
static_cast<float*>(output_s.data_ptr()),
hidden_dim,
num_tokens);
return true;
});
}
......@@ -25,9 +25,11 @@ limitations under the License.
#define WARP_SIZE 32
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer, size_t numel) {
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ cumsum_buffer,
size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
......@@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__
}
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel,
int32_t* __restrict__ cumsum) {
__shared__ int32_t shared_counts[WARP_SIZE][8];
const int warp_id = threadIdx.x / WARP_SIZE;
......@@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id
}
}
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
int64_t block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
torch::Tensor token_cnts_buffer,
torch::Tensor cumsum_buffer) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
align_kernel<<<1, 1024, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel(),
cumsum_buffer.data_ptr<int32_t>());
const int block_threads = 256;
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
......@@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
const int actual_blocks = std::min(num_blocks, max_blocks);
auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(),
topk_ids.numel());
});
}
......@@ -23,10 +23,18 @@
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len,
bool* tree_mask, int64_t* positions, int64_t* retrive_index,
int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth,
int draft_token_num) {
__global__ void build_tree_efficient(
int64_t* parent_list,
int64_t* selected_index,
int32_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
......@@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind
}
}
void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
int64_t depth, int64_t draft_token_num) {
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
......@@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()), static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk), int32_t(depth), int32_t(draft_token_num));
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
}
// parent_list [bs, topk * (depth - 1) + 1)]
......@@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token, depth + 2]
__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask,
int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) {
__global__ void build_tree(
int64_t* parent_list,
int64_t* selected_index,
int32_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int topk,
int depth,
int draft_token_num) {
int bid = blockIdx.x;
int tid = threadIdx.x;
......@@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_
}
}
void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
int64_t depth, int64_t draft_token_num) {
void build_tree_kernel(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
int64_t topk,
int64_t depth,
int64_t draft_token_num) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
......@@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
build_tree<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()), static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()), static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()), static_cast<int64_t*>(retrive_index.data_ptr()), int32_t(topk),
int32_t(depth), int32_t(draft_token_num));
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int32_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment