Unverified Commit 8d09630a authored by gongchensu's avatar gongchensu Committed by GitHub
Browse files

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
#ifndef __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
#define __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
/**
* @brief Symmetric dequantization kernel for post-processing quantized matrix multiplication
*
* This kernel performs symmetric dequantization on the packed integer output from
* a quantized matrix multiplication. It converts integer results back to floating-point
* values by applying per-row input scales and per-column weight scales,
* then adds bias terms.
*
* The dequantization formula is:
* y = x_scale * w_scale * y_packed + bias
*
* @tparam Tdata Output data type (typically bfloat16 or half)
*
* @param[out] y Output tensor after dequantization
* Shape: [M, N], Data type: Tdata
*
* @param[in] y_packed Packed integer output from quantized matmul
* Shape: [M, N], Data type: int32_t
* Contains integer results of: x_packed[i,:] * w_packed[:,j]
*
* @param[in] bias Bias tensor to add after dequantization
* Shape: [N], Data type: Tdata
* Broadcasted across all rows
*
* @param[in] x_packed Packed quantized input tensor (not directly used here)
* Shape: [M, K], Data type: int8_t
* Included for context of the computation pipeline
*
* @param[in] x_scale Per-row (per-token) scaling factors for input
* Shape: [M], Data type: float
* One scale value per input row
*
* @param[in] w_packed Packed quantized weight tensor (not directly used here)
* Shape: [K, N], Data type: int8_t
* Included for context of the computation pipeline
*
* @param[in] w_scale Per-column (per-channel) scaling factors for weights
* Shape: [N], Data type: float
* One scale value per output column
*
* @param[in] M Batch size / number of input rows
*
* @param[in] K Inner dimension of matrix multiplication
*
* @param[in] N Output dimension / number of output columns
*
* @note This kernel assumes symmetric quantization (zero-point = 0)
* @note Each thread processes one element of the output matrix
* @note Grid and block dimensions should be configured to cover [M, N] output space
*/
// Dequantize one output element with bias:
//   y[row, col] = x_scale[row] * w_scale[col] * y_packed[row, col] + bias[col]
// Expects a 2-D launch covering [N (x), M (y)]; each thread handles one element.
// See the documentation block above for full parameter semantics.
template <typename Tdata>
__device__ void postSymKernel(Tdata *y, const int32_t *y_packed, const Tdata *bias, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, int M, int K, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    // Guard the grid tail: the launch grid rarely divides [M, N] exactly.
    if (row >= M || col >= N) {
        return;
    }
    // Compute the flat index in size_t so it cannot overflow a 32-bit int
    // when M * N exceeds 2^31.
    size_t idx = (size_t)row * N + col;
    float output1 = x_scale[row] * w_scale[col] * ((float)y_packed[idx]);
    float output = output1 + (float)bias[col];
    y[idx] = static_cast<Tdata>(output);
}
// y = x_scale * w_scale * y_packed
// Dequantize one output element without bias:
//   y[row, col] = x_scale[row] * w_scale[col] * y_packed[row, col]
// Expects a 2-D launch covering [N (x), M (y)]; each thread handles one element.
template <typename Tdata>
__device__ void postSymKernel(Tdata *y, const int32_t *y_packed, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, int M, int K, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    // Guard the grid tail: the launch grid rarely divides [M, N] exactly.
    if (row >= M || col >= N) {
        return;
    }
    // Compute the flat index in size_t so it cannot overflow a 32-bit int
    // when M * N exceeds 2^31.
    size_t idx = (size_t)row * N + col;
    float output = x_scale[row] * w_scale[col] * ((float)y_packed[idx]);
    y[idx] = static_cast<Tdata>(output);
}
#endif // __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
#ifndef __I8GEMM_INFO_H__
#define __I8GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::i8gemm {
// Describes a (possibly batched) matrix operand for the int8 GEMM in terms of
// its logical shape and element strides, as extracted from an infiniop tensor
// descriptor.
struct BlasMatrix {
    int ndim;       // 2 for a plain matrix, 3 for a batched matrix
    int batch;      // number of matrices in the batch (1 when ndim == 2)
    int stride;     // batch stride in elements (0 when there is no real batch)
    int rows;       // logical row count
    int cols;       // logical column count
    int row_stride; // element stride between consecutive rows
    int col_stride; // element stride between consecutive columns

    // Builds a BlasMatrix from a 2-D or 3-D tensor descriptor.
    // Fails with INFINI_STATUS_BAD_TENSOR_SHAPE for any other rank, and with
    // INFINI_STATUS_BAD_TENSOR_STRIDES when neither stride is 1 (the matrix
    // is neither row- nor column-contiguous).
    // NOTE(review): dims/strides are narrowed into int fields — assumes they
    // fit in 32 bits; confirm against the descriptor's dim/stride width.
    static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
        BlasMatrix ans;
        if (layout->ndim() == 2) {
            ans.ndim = 2;
            ans.batch = 1;
            ans.stride = 0;
            ans.rows = layout->dim(0);
            ans.cols = layout->dim(1);
            ans.row_stride = layout->stride(0);
            ans.col_stride = layout->stride(1);
        } else if (layout->ndim() == 3) {
            ans.ndim = 3;
            ans.batch = layout->dim(0);
            // A size-1 batch dimension is treated as unbatched (stride 0) so
            // the operand can broadcast across the batch.
            ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
            ans.rows = layout->dim(1);
            ans.cols = layout->dim(2);
            ans.row_stride = layout->stride(1);
            ans.col_stride = layout->stride(2);
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (ans.row_stride != 1 && ans.col_stride != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        return utils::Result<BlasMatrix>(ans);
    }

    // True when this operand is usable with the given batch count:
    // exact match, or batch == 1 (broadcast).
    bool match_batch(int _batch) const {
        return batch == _batch || batch == 1;
    }

    // Logical in-place transpose: swaps rows/cols and their strides.
    void transpose() {
        std::swap(rows, cols);
        std::swap(row_stride, col_stride);
    }

    // Leading dimension for BLAS-style calls: the stride of whichever
    // dimension is NOT contiguous.
    int ld() const {
        return row_stride == 1 ? col_stride : row_stride;
    }
};
// Memory order of a matrix operand passed to the GEMM entry points.
enum class MatrixLayout : char {
    COL_MAJOR, // column-major (Fortran/BLAS order)
    ROW_MAJOR, // row-major (C order)
};
class I8GemmInfo {
I8GemmInfo() = default;
public:
BlasMatrix a_matrix;
BlasMatrix b_matrix;
BlasMatrix out_matrix;
int m, n, k, batch;
static utils::Result<I8GemmInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
MatrixLayout layout) {
auto a_matrix = BlasMatrix::create(a_desc);
CHECK_RESULT(a_matrix);
auto b_matrix = BlasMatrix::create(b_desc);
CHECK_RESULT(b_matrix);
auto out_matrix = BlasMatrix::create(out_desc);
CHECK_RESULT(out_matrix);
if (out_matrix->rows != a_matrix->rows || out_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto batch = out_matrix->batch;
if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto m = out_matrix->rows;
auto n = out_matrix->cols;
auto k = a_matrix->cols;
return utils::Result<I8GemmInfo>(I8GemmInfo{
a_matrix.take(),
b_matrix.take(),
out_matrix.take(),
m,
n,
k,
batch});
}
};
} // namespace op::i8gemm
#endif // __I8GEMM_INFO_H__
#ifndef __I8GEMM_H__
#define __I8GEMM_H__
#include "../../operator.h"
#include "info.h"
// Declares the per-backend i8gemm Descriptor class inside
// op::i8gemm::<NAMESPACE>. The constructor's initializer list is written in
// member declaration order (_opaque, _workspace_size, _info, _out_dtype);
// members are always initialized in declaration order regardless of the list,
// so keeping them aligned avoids misleading code and -Wreorder warnings.
#define DESCRIPTOR(NAMESPACE)                                                     \
                                                                                  \
    namespace op::i8gemm::NAMESPACE {                                             \
    class Descriptor final : public InfiniopDescriptor {                          \
        struct Opaque;                                                            \
        Opaque *_opaque;                                                          \
        size_t _workspace_size;                                                   \
        I8GemmInfo _info;                                                         \
        infiniDtype_t _out_dtype;                                                 \
                                                                                  \
        Descriptor(Opaque *opaque, I8GemmInfo info,                               \
                   size_t workspace_size,                                         \
                   infiniDtype_t out_dtype,                                       \
                   infiniDevice_t device_type, int device_id)                     \
            : InfiniopDescriptor{device_type, device_id},                         \
              _opaque(opaque), _workspace_size(workspace_size),                   \
              _info(info), _out_dtype(out_dtype) {}                               \
                                                                                  \
    public:                                                                       \
        ~Descriptor();                                                            \
                                                                                  \
        size_t minWorkspaceSize() const { return _workspace_size; }               \
                                                                                  \
        static infiniStatus_t create(                                             \
            infiniopHandle_t handle, Descriptor **desc_ptr,                       \
            infiniopTensorDescriptor_t out_desc,                                  \
            infiniopTensorDescriptor_t bias_desc,                                 \
            infiniopTensorDescriptor_t a_desc,                                    \
            infiniopTensorDescriptor_t a_scale_desc,                              \
            infiniopTensorDescriptor_t b_desc,                                    \
            infiniopTensorDescriptor_t b_scale_desc);                             \
        template <unsigned int BLOCK_SIZE, typename Tdata>                        \
        infiniStatus_t launchKernel(const I8GemmInfo &info, Tdata *y,             \
                                    const Tdata *bias, const int8_t *x_packed,    \
                                    const float *x_scale, const int8_t *w_packed, \
                                    const float *w_scale, void *stream, void *workspace) const; \
                                                                                  \
        infiniStatus_t calculate(                                                 \
            void *workspace, size_t workspace_size,                               \
            void *out, const void *bias, const void *a,                           \
            const void *a_scale, const void *b,                                   \
            const void *b_scale, void *stream) const;                             \
    };                                                                            \
    }
#endif // __I8GEMM_H__
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
#pragma once
#include <cutlass/arch/memory.h>
#include <cutlass/numeric_conversion.h>
namespace cutlass {
namespace epilogue {
namespace threadblock {
/// Epilogue visitor that applies per-row (token) and/or per-column (channel)
/// dequantization scales — and optionally a bias row read through iterator_C —
/// to GEMM accumulator fragments before writing the output tile.
/// NOTE(review): the UseMasking_ parameter is not referenced anywhere in this
/// class body — confirm whether it is consumed by a partial specialization.
template <
    typename ThreadblockShape_,
    int ThreadCount,
    typename ScaleTileIterator_,
    typename OutputTileIterator_,
    typename ElementAccumulator_,
    typename ElementCompute_,
    typename ElementwiseFunctor_,
    bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;
    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;
    using ElementwiseFunctor = ElementwiseFunctor_;
    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
    using ElementOutput = typename OutputTileIterator::Element;
    using LayoutOutput = cutlass::layout::RowMajor;
    using ElementAccumulator = ElementAccumulator_;
    using AlphaScaleElementType = typename ScaleTileIterator::Element;
    using ElementCompute = ElementCompute_;
    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
    using OutputVector = Array<ElementOutput, kElementsPerAccess>;
    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

    /// Argument structure
    struct Arguments {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(typename ElementwiseFunctor::Params elementwise_)
            : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(
            typename ElementwiseFunctor::Params elementwise_,
            int64_t batch_stride_alpha_,
            int64_t batch_stride_C_,
            int64_t batch_stride_D_)
            : elementwise(elementwise_),
              batch_stride_alpha(batch_stride_alpha_),
              batch_stride_C(batch_stride_C_),
              batch_stride_D(batch_stride_D_) {}
    };

    /// Kernel-side parameters, copied field-for-field from Arguments.
    struct Params {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params() {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const &args)
            : elementwise(args.elementwise),
              batch_stride_alpha(args.batch_stride_alpha),
              batch_stride_C(args.batch_stride_C),
              batch_stride_D(args.batch_stride_D) {}
    };

    /// Shared storage
    struct SharedStorage {};

private:
    Params const &params_;
    SharedStorage &shared_storage_;
    MatrixCoord extent_;      // logical problem extent used for bounds checks
    MatrixCoord extent_real_; // caller-provided secondary extent; defaults to (0, 0)
    ElementwiseFunctor elementwise_;
    bool const with_bias_;         // add the bias row (fragment_C_) to the result
    bool const per_token_quant_;   // load a fresh row scale in begin_row()
    bool const per_channel_quant_; // load column scales once in begin_epilogue()
    AlphaScaleElementType *ptr_alpha_row_;
    AlphaScaleElementType *ptr_alpha_col_;
    ScaleTileIterator iterator_alpha_col_;
    OutputTileIterator iterator_C_;
    OutputTileIterator iterator_D_;
    // Defaults of 1.0f apply when a scale pointer is null (no scaling).
    AlphaScaleElementType element_alpha_row_ = 1.0f;
    AlphaScaleElementType element_alpha_col_ = 1.0f;
    typename ScaleTileIterator::Fragment fragment_alpha_col_;
    typename OutputTileIterator::Fragment fragment_C_;
    typename OutputTileIterator::Fragment fragment_D_;
    ElementAccumulator beta_;
    int column_offset_;
    MatrixCoord thread_offset_;

public:
    CUTLASS_DEVICE
    EpilogueVisitorPerRowPerCol(
        Params const &params,
        SharedStorage &shared_storage,
        cutlass::MatrixCoord const &problem_size,
        int thread_idx,
        int warp_idx,
        int lane_idx,
        typename ScaleTileIterator::Params params_alpha_col,
        typename OutputTileIterator::Params params_C,
        typename OutputTileIterator::Params params_D,
        bool with_bias,
        bool per_token_quant,
        bool per_channel_quant,
        AlphaScaleElementType *ptr_alpha_row,
        AlphaScaleElementType *ptr_alpha_col,
        typename OutputTileIterator::Element *ptr_C,
        typename OutputTileIterator::Element *ptr_D,
        cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
        int column_offset = 0,
        cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0))
        // NOTE(review): members are initialized in declaration order, not the
        // order written here (extent_real_ is actually set before the
        // iterators). Harmless since no initializer reads another member.
        : params_(params),
          shared_storage_(shared_storage),
          extent_(problem_size),
          elementwise_(params.elementwise),
          with_bias_(with_bias),
          per_token_quant_(per_token_quant),
          per_channel_quant_(per_channel_quant),
          ptr_alpha_row_(ptr_alpha_row),
          ptr_alpha_col_(ptr_alpha_col),
          iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
          iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
          iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
          extent_real_(problem_size_real) {
        // When quantization is not per-channel / per-token, the scale is a
        // single scalar: read it once here instead of via the tile iterators.
        if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
            element_alpha_col_ = *ptr_alpha_col_;
        }
        if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
            element_alpha_row_ = *ptr_alpha_row_;
        }
    }

    /// Helper to indicate split-K behavior
    CUTLASS_DEVICE
    void set_k_partition(
        int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
        int split_k_slices) { ///< Total number of split-K slices
        // Intentionally empty: this visitor has no split-K specific state.
    }

    /// Called to set the batch index
    CUTLASS_DEVICE
    void set_batch_index(int batch_idx) {
        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
    }

    /// Called at the start of the epilogue just before iterating over accumulator slices
    CUTLASS_DEVICE
    void begin_epilogue() {
        // Column scales and the bias row are loop-invariant: load them once.
        if (per_channel_quant_) {
            iterator_alpha_col_.load(fragment_alpha_col_);
        }
        if (with_bias_) {
            iterator_C_.load(fragment_C_);
        }
    }

    /// Called at the start of one step before starting accumulator exchange
    CUTLASS_DEVICE
    void begin_step(int step_idx) {
        fragment_D_.clear();
    }

    /// Called at the start of a row
    CUTLASS_DEVICE
    void begin_row(int row_idx) {
        // load alpha_row in begin_step only when per token(row) scaling is used
        if (per_token_quant_) {
            int thread_offset_row = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
            // Predicated load: rows past the problem extent leave
            // element_alpha_row_ unchanged.
            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
        }
    }

    /// Called after accumulators have been exchanged for each accumulator vector
    CUTLASS_DEVICE
    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const &accum) {
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
        ComputeFragment result = source_converter(accum);
        if (per_channel_quant_) {
            // NOTE(review): this reinterpret assumes the scale fragment's
            // element layout matches ComputeFragment — confirm instantiations
            // where AlphaScaleElementType != ElementCompute.
            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment *>(&fragment_alpha_col_)[column_idx];
            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
        } else {
            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
        }
        if (with_bias_) {
            NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
            OutputVector bias = reinterpret_cast<OutputVector *>(&fragment_C_)[column_idx];
            result = bias_accumulator_(result, bias_converter(bias));
        }
        // Convert to the output
        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
        OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
        output = output_converter(result);
    }

    /// Called at the end of a row
    CUTLASS_DEVICE
    void end_row(int row_idx) {}

    /// Called after all accumulator elements have been visited
    CUTLASS_DEVICE
    void end_step(int step_idx) {
        iterator_D_.store(fragment_D_);
        ++iterator_D_;
    }

    /// Called after all steps have been completed
    CUTLASS_DEVICE
    void end_epilogue() {}

private:
    /// result[i] = accum[i] * scale_col[i] * scale_row (per-row AND per-column)
    CUTLASS_DEVICE
    ComputeFragment per_token_channel_scale_accumulator_(
        ComputeFragment const &accum, ComputeFragment const &scale_col, AlphaScaleElementType const &scale_row) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col[i] * scale_row);
        }
        return result;
    }

    /// result[i] = accum[i] * scale_col * scale_row (scalar scales)
    CUTLASS_DEVICE
    ComputeFragment per_token_scale_accumulator_(
        ComputeFragment const &accum, AlphaScaleElementType const &scale_col, AlphaScaleElementType const &scale_row) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col * scale_row);
        }
        return result;
    }

    /// result[i] = accum[i] + bias[i]
    /// (OutputVector::kElements == kElementsPerAccess == ComputeFragment::kElements)
    CUTLASS_DEVICE
    ComputeFragment bias_accumulator_(ComputeFragment const &accum, ComputeFragment const &bias) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < OutputVector::kElements; ++i) {
            result[i] = accum[i] + bias[i];
        }
        return result;
    }
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>
#include <cutlass/trace.h>
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
that feature at the moment.
*/
/// Host-side device layer (CUTLASS 2.10 compatible) that configures and
/// launches a kernel-level GEMM: grid shaping, split-K fix-up, workspace
/// sizing, occupancy queries, and the kernel launch itself.
template <typename GemmKernel_>
class GemmUniversalBaseCompat {
public:
    using GemmKernel = GemmKernel_;
    using ThreadblockShape = typename GemmKernel::Mma::Shape;
    using ElementA = typename GemmKernel::ElementA;
    using LayoutA = typename GemmKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = GemmKernel::kTransformA;
    using ElementB = typename GemmKernel::ElementB;
    using LayoutB = typename GemmKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = GemmKernel::kTransformB;
    using ElementC = typename GemmKernel::ElementC;
    using LayoutC = typename GemmKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;
    using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
    using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
    using Operator = typename GemmKernel::Operator;

    /// Argument structure
    using Arguments = typename GemmKernel::Arguments;

protected:
    /// Kernel parameters object
    typename GemmKernel::Params params_;

protected:
    /// Private helper to obtain the grid dimensions with fix-up for split-K
    static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;
        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
            args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
        gemm_k_size = args.problem_size.k();
        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
            // Keep each split-K slice 128-bit aligned for both operand types.
            int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
            if (gemm_k_size) {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }
    }

public:
    /// Constructs the GEMM.
    GemmUniversalBaseCompat() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const &args) {
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;
        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
        // grid.y/grid.z are limited to 16 bits (65535) by the hardware.
        uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
        if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
            return Status::kErrorInvalidProblem;
        }
        return GemmKernel::can_implement(args);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const &args) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
        size_t workspace_bytes = 0;
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;
        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
            // Split-K parallel always requires a temporary workspace
            workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
        } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
            // Serial split-K only requires a temporary workspace if the number of partitions along the
            // GEMM K dimension is greater than one.
            workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
        }
        CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
        workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
        return workspace_bytes;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const &args) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
        ThreadblockSwizzle threadblock_swizzle;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;
        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
        CUTLASS_TRACE_HOST(
            " grid_tiled_shape: " << grid_tiled_shape << "\n"
                                  << " result = {" << result << "}");
        return result;
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
        int max_active_blocks = -1;
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
        CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
        if (smem_size <= (48 << 10)) {
            // Fits the default 48 KB shared-memory carve-out: the occupancy
            // API can be queried with the real shared-memory size directly.
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
            if (result == cudaSuccess) {
                CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
                return max_active_blocks;
            }
        } else {
            // Query assuming zero shared memory then compute occupancy limit based on SMEM
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
            if (result != cudaSuccess) {
                CUTLASS_TRACE_HOST(
                    " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
                return -1;
            }
            if (smem_capacity < 0) {
                // Caller did not supply the per-SM shared-memory capacity:
                // query it from the current device.
                int device_idx = 0;
                result = cudaGetDevice(&device_idx);
                if (result != cudaSuccess) {
                    return -1;
                }
                cudaDeviceProp properties;
                result = cudaGetDeviceProperties(&properties, device_idx);
                if (result != cudaSuccess) {
                    return -1;
                }
                smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
            }
            int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
            CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
            return occupancy;
        }
        // Reached only when the <=48 KB occupancy query above failed.
        CUTLASS_TRACE_HOST(" returning internal error");
        return -1;
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
        CUTLASS_TRACE_HOST(
            "GemmUniversalBaseCompat::initialize() - workspace " << workspace
                                                                 << ", stream: " << (stream ? "non-null" : "null"));
        size_t workspace_bytes = get_workspace_size(args);
        CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
        if (workspace_bytes) {
            if (!workspace) {
                CUTLASS_TRACE_HOST(" error: device workspace must not be null");
                return Status::kErrorWorkspaceNull;
            }
            // The serial split-K path requires a zero-initialized workspace.
            if (args.mode == GemmUniversalMode::kGemm) {
                CUTLASS_TRACE_HOST(" clearing device workspace");
                cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
                if (result != cudaSuccess) {
                    CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
                    return Status::kErrorInternal;
                }
            }
        }
        // Get CUDA grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;
        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        // Initialize the Params structure
        params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int *>(workspace));
        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
        if (smem_size >= (48 << 10)) {
            // Beyond 48 KB we must opt in to large dynamic shared memory.
            cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
            if (result != cudaSuccess) {
                return Status::kErrorInternal;
            }
        }
        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const &args, void *workspace = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
        size_t workspace_bytes = get_workspace_size(args);
        if (workspace_bytes && !workspace) {
            return Status::kErrorWorkspaceNull;
        }
        params_.update(args, workspace);
        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
        //
        // Configure grid and block dimensions
        //
        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
        //
        // Launch kernel
        //
        CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
        // Launch
        cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
        //
        // Query for errors
        //
        // Launch errors (bad config) surface here; execution errors surface
        // at the caller's next synchronizing call.
        cudaError_t result = cudaGetLastError();
        if (result != cudaSuccess) {
            CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
            return Status::kErrorInternal;
        }
        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr) {
        return run(stream);
    }

    /// Runs the kernel using initialized state.
    Status operator()(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
        Status status = initialize(args, workspace, stream);
        if (status == Status::kSuccess) {
            status = run(stream);
        }
        return status;
    }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
#pragma once
#include <cutlass/complex.h>
#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/trace.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementCompute = typename EpilogueVisitor::ElementCompute;
using LayoutAlphaCol = cutlass::layout::RowMajor;
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using EpilogueOutputOp =
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
TensorRefAlphaCol ref_alpha_col;
TensorRefAlphaRow ref_alpha_row;
TensorRefC ref_C;
TensorRefC ref_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_D;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
/// constructs an arguments structure
Arguments(
GemmCoord problem_size_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_,
TensorRefC ref_C_,
TensorRefC ref_D_,
typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(GemmUniversalMode::kGemm),
problem_size(problem_size_),
batch_count(1),
ref_A(ref_A_),
ref_B(ref_B_),
ref_alpha_col(ref_alpha_col_),
ref_alpha_row(ref_alpha_row_),
ref_C(ref_C_),
ref_D(ref_D_),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_D(0),
epilogue_visitor(epilogue_visitor_) {}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void *ptr_A;
void *ptr_B;
typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_row;
ElementC *ptr_C;
ElementC *ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0),
params_A(0),
params_B(0),
params_alpha_col(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_alpha_col(nullptr),
ptr_alpha_row(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0) {}
Params(Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape_, int gemm_k_size_, int *workspace_)
: problem_size(args.problem_size),
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_alpha_col(args.ref_alpha_col.layout()),
params_alpha_row(args.ref_alpha_col.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_alpha_col(args.ref_alpha_col.data()),
ptr_alpha_row(args.ref_alpha_row.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor) {
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
};
/// Shared memory storage structure
// Mainloop and epilogue do not use shared memory at the same time, so their
// storages alias in a union to shrink the per-CTA shared-memory footprint.
union SharedStorage {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
// Trivial device-side constructor; all launch state arrives via Params.
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
/// Determines whether kernel satisfies alignment
// For each operand, the layout decides which problem-size extent must be a
// multiple of that operand's vector access width. An unrecognized layout
// yields an extent of 0, which trivially passes (as in the original checks).
static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) {
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
// Extent of A that must be divisible by kAlignmentA.
int contiguous_A = 0;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
contiguous_A = problem_size.k();
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
contiguous_A = problem_size.m();
} else if (
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
contiguous_A = problem_size.k();
}
if (contiguous_A % kAlignmentA) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
// Extent of B that must be divisible by kAlignmentB.
int contiguous_B = 0;
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
contiguous_B = problem_size.n();
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
contiguous_B = problem_size.k();
} else if (
platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
contiguous_B = problem_size.k();
}
if (contiguous_B % kAlignmentB) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
// Extent of C that must be divisible by kAlignmentC.
int contiguous_C = 0;
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
contiguous_C = problem_size.n();
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
contiguous_C = problem_size.m();
} else if (
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
contiguous_C = problem_size.n();
}
if (contiguous_C % kAlignmentC) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
// Convenience overload: validates alignment for the problem size carried in
// the full argument structure.
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
// This kernel variant needs no extra device workspace (always 0 bytes).
static size_t get_extra_workspace_size(Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
// Device-side kernel body: resolves this CTA's output tile, runs the
// threadblock-scoped mainloop over its K range, then applies the
// per-row/per-column scaling epilogue visitor to the accumulators.
CUTLASS_DEVICE
void run_kernel_(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
// kGemm / kGemmSplitKParallel: restrict this CTA to its K slice.
// kBatched: the grid's K dimension indexes the batch; advance base pointers.
// kArray: params.ptr_A/ptr_B are arrays of per-batch pointers.
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
} else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
} else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA *const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB *const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
// NOTE(review): block_idx appears unused below -- candidate for removal.
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
// A null C pointer means "no bias": the visitor skips the source operand.
bool with_bias = true;
if (params.ptr_C == nullptr) {
with_bias = false;
}
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params.params_alpha_col,
params.params_C,
params.params_D,
with_bias,
true,
true,
params.ptr_alpha_row,
params.ptr_alpha_col,
params.ptr_C,
params.ptr_D,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
} else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
// Compile the kernel body only for the translation unit whose target
// architecture matches this kernel's ArchTag; otherwise emit a trap.
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const &params, SharedStorage &shared_storage) {
if constexpr (!platform::is_same<ArchTag, CompilationArch>::value) {
CUTLASS_NOT_IMPLEMENTED();
} else {
run_kernel_(params, shared_storage);
}
}
/// Executes one GEMM
// Device-side entry point; forwards to run_kernel<ArchTag>.
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
run_kernel<ArchTag>(params, shared_storage);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/numeric_types.h>
#include <cute/atom/mma_atom.hpp>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "epilogue_per_row_per_col_scale.h"
#include "gemm_universal_base_compat.h"
#include "gemm_with_epilogue_visitor.h"
using namespace cute;
// Maps a CUTLASS status onto the infini status space: kSuccess passes
// through, anything else collapses to an internal error.
inline infiniStatus_t check_cutlass_status(cutlass::Status status) {
return (status == cutlass::Status::kSuccess) ? INFINI_STATUS_SUCCESS
                                             : INFINI_STATUS_INTERNAL_ERROR;
}
// Runs a CUTLASS 2.x int8 GEMM (row-major A x column-major B) whose epilogue
// visitor dequantizes with per-row (a_scale) and per-column (b_scale) factors
// and optionally adds a bias row, writing ElementOutput to `out` on `stream`.
template <
typename ElementOutput,
typename ArchTag,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
int NumStages>
void cutlass_int8_scaled_mm(
void *out,
const void *a,
const void *b,
const void *a_scale,
const void *b_scale,
const void *bias,
int m,
int n,
int k,
int lda,
int ldb,
int ldd,
void *stream) {
using ElementAccumulator = int32_t;
using ElementCompute = float;
using ElementInputA = int8_t;
using ElementInputB = int8_t;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>;
using DefaultGemmConf = cutlass::gemm::device::
DefaultGemmConfiguration<OperatorClass, ArchTag, ElementInputA, ElementInputB, ElementOutput, ElementCompute>;
using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp;
// Base kernel; its Mma and output-tile thread map are reused below.
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
ElementInputA,
cutlass::layout::RowMajor,
DefaultGemmConf::kAlignmentA,
ElementInputB,
cutlass::layout::ColumnMajor,
DefaultGemmConf::kAlignmentB,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
NumStages,
true,
typename DefaultGemmConf::Operator>::GemmKernel;
// Iterator that streams the float scale vector with the same thread map as
// the output tile.
using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape,
typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count,
GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads,
GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess,
cutlass::sizeof_bits<ElementOutput>::value>,
ElementCompute>;
using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
ThreadblockShape,
GemmKernel_::kThreadCount,
AlphaColTileIterator,
typename GemmKernel_::Epilogue::OutputTileIterator,
ElementAccumulator,
ElementCompute,
EpilogueOutputOp>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpilogueWithVisitorFromExistingEpilogue<EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue;
using GemmKernel = cutlass::gemm::kernel::GemmWithEpilogueVisitor<typename GemmKernel_::Mma, Epilogue, ThreadblockSwizzle>;
using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;
Gemm gemm_op;
auto a_ptr = static_cast<ElementInputA *>(const_cast<void *>(a));
auto b_ptr = static_cast<ElementInputB *>(const_cast<void *>(b));
auto o_ptr = static_cast<ElementOutput *>(const_cast<void *>(out));
auto a_s_ptr = static_cast<ElementCompute *>(const_cast<void *>(a_scale));
auto b_s_ptr = static_cast<ElementCompute *>(const_cast<void *>(b_scale));
ElementOutput *bias_ptr = nullptr;
// Bias is a single row broadcast across M, hence leading dimension 0; a null
// bias_ptr disables the source operand in the kernel.
int64_t ldc = 0;
if (bias) {
bias_ptr = static_cast<ElementOutput *>(const_cast<void *>(bias));
}
typename EpilogueOutputOp::Params linearScalingParams;
typename EpilogueVisitor::Arguments visitor_args{linearScalingParams};
// Argument order maps b_scale into the alpha_col slot and a_scale into the
// alpha_row slot; stride 0 marks both as vectors.
typename Gemm::Arguments args{
{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args};
// NOTE(review): the infiniStatus_t results below are discarded -- this
// function is void, so callers cannot observe a failed can_implement or
// launch. Consider propagating the status.
check_cutlass_status(gemm_op.can_implement(args));
auto status = gemm_op(args, nullptr, (cudaStream_t)stream);
check_cutlass_status(status);
}
// Tile-shape dispatch for SM75-class GPUs: picks threadblock/warp shapes from
// the M extent. All variants use a 2-stage pipeline.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm75_dispatch_shape(
    void *out,
    const void *a,
    const void *b,
    const void *a_scale,
    const void *b_scale,
    const void *bias,
    int m,
    int n,
    int k,
    int lda,
    int ldb,
    int ldd,
    void *stream) {
    // Small batch: narrow M tile.
    if (m <= 32) {
        return cutlass_int8_scaled_mm<
            ElementOutput, ArchTag,
            cutlass::gemm::GemmShape<32, 128, 64>,
            cutlass::gemm::GemmShape<32, 64, 64>,
            InstructionShape, 2>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
    if (m <= 64) {
        return cutlass_int8_scaled_mm<
            ElementOutput, ArchTag,
            cutlass::gemm::GemmShape<64, 128, 128>,
            cutlass::gemm::GemmShape<64, 64, 64>,
            InstructionShape, 2>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
    if (m <= 256) {
        return cutlass_int8_scaled_mm<
            ElementOutput, ArchTag,
            cutlass::gemm::GemmShape<128, 128, 128>,
            cutlass::gemm::GemmShape<64, 64, 64>,
            InstructionShape, 2>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
    // Large batch: widest M tile with a shallower K tile.
    return cutlass_int8_scaled_mm<
        ElementOutput, ArchTag,
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        InstructionShape, 2>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
// Tile-shape dispatch for SM80-class GPUs: selects threadblock/warp shapes
// and pipeline stage counts from (m, n). Thresholds are performance
// heuristics; every branch computes the same result.
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(
void *out,
const void *a,
const void *b,
const void *a_scale,
const void *b_scale,
const void *bias,
int m,
int n,
int k,
int lda,
int ldb,
int ldd,
void *stream) {
// m <= 16: narrow tiles; deeper pipeline (6 stages) for modest n.
if (m <= 16) {
if (n <= 4096) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
6>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 32) {
if (n <= 4096) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
6>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 64) {
// Widen the N tile for large n.
if (n <= 4096) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 128 && n < 8192) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
// Large m (or large n at m <= 128): biggest threadblock tile.
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
}
// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
// Tile-shape dispatch for SM89-class GPUs (see vLLM reference above):
// selects threadblock/warp shapes and pipeline depth from (m, n).
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm89_dispatch_shape(
void *out,
const void *a,
const void *b,
const void *a_scale,
const void *b_scale,
const void *bias,
int m,
int n,
int k,
int lda,
int ldb,
int ldd,
void *stream) {
if (m <= 16) {
if (n <= 8192) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<16, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<16, 128, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
4>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 32) {
if (n <= 8192) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 64, 128>,
cutlass::gemm::GemmShape<16, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<32, 128, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
4>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 64) {
if (n <= 8192) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
3>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 128) {
// Mid-range m: three n tiers trade tile width against pipeline depth.
if (n <= 8192) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
3>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else if (n <= 16384) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 256) {
// Larger m: four n tiers, up to a 256-wide M tile for mid-size n.
if (n <= 4096) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
3>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else if (n <= 8192) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else if (n <= 16384) {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
3>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else {
cutlass_int8_scaled_mm<
ElementOutput,
ArchTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
InstructionShape,
5>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
}
// SM90 (Hopper) int8 scaled GEMM using the CUTLASS 3.x collective builders.
// The epilogue visitor tree computes:
//   D = x_scale * (w_scale * acc) [+ bias]   (bias branch selected at
// compile time via the WithBias template parameter).
template <
typename ElementOutput,
typename TileShape,
typename ClusterShape,
typename MainloopScheduleType,
bool WithBias>
void cutlass_int8_scaled_mm_sm90(
void *out,
const void *a,
const void *b,
const void *a_scale,
const void *b_scale,
const void *bias,
int m,
int n,
int k,
int lda,
int ldb,
int ldd,
void *stream) {
using ArchTag = cutlass::arch::Sm90;
using ElementAccumulator = int32_t;
using ElementCompute = float;
using ElementInputA = int8_t;
using ElementInputB = int8_t;
// 128-bit vectorized access widths per element type.
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
using TileSchedulerType = cutlass::gemm::PersistentScheduler;
// EVT leaf nodes: per-row (column-broadcast) x-scale, per-column
// (row-broadcast) w-scale, row-broadcast bias, and the raw accumulator.
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;
using WScale = cutlass::epilogue::fusion::
Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;
using Bias = cutlass::epilogue::fusion::
Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
// Scale
using Compute0 = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
using Compute1 = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
// With bias
using ComputeWithBias = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiply_add, ElementOutput, ElementCompute, cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;
using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementOutput,
cutlass::layout::RowMajor,
AlignmentC,
ElementOutput,
cutlass::layout::RowMajor,
AlignmentOutput,
EpilogueScheduleType,
EpilogueEVT>::CollectiveOp;
// Mainloop stage count is auto-derived after carving out the epilogue's
// shared-memory footprint.
using Stages = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementInputA,
cutlass::layout::RowMajor,
AlignmentA,
ElementInputB,
cutlass::layout::ColumnMajor,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
Stages,
MainloopScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
TileSchedulerType>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Gemm gemm_op;
auto a_ptr = static_cast<ElementInputA *>(const_cast<void *>(a));
auto b_ptr = static_cast<ElementInputB *>(const_cast<void *>(b));
auto o_ptr = static_cast<ElementOutput *>(const_cast<void *>(out));
auto a_s_ptr = static_cast<ElementCompute *>(const_cast<void *>(a_scale));
auto b_s_ptr = static_cast<ElementCompute *>(const_cast<void *>(b_scale));
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
// NOTE(review): stride_c is left default-constructed; the C operand pointer
// below is nullptr, so it should never be dereferenced -- confirm.
StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
typename Gemm::Arguments args = {
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{a_ptr, stride_a, b_ptr, stride_b},
{{}, // epilogue.thread
nullptr,
stride_c,
o_ptr,
stride_d}};
// Fill the EVT argument tree in the same order as the EVT node structure:
// {XScale, {WScale, ...}, [Bias,] {}}.
if constexpr (WithBias) {
ElementOutput *bias_ptr = static_cast<ElementOutput *>(const_cast<void *>(bias));
// ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
args.epilogue.thread = {
{a_s_ptr},
{{b_s_ptr}, {}, {}},
{bias_ptr},
{},
};
} else {
args.epilogue.thread = {
{a_s_ptr},
{{b_s_ptr}, {}, {}},
{},
};
}
// auto workspace = torch::empty(
//     gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
// NOTE(review): statuses are checked but the infiniStatus_t is discarded;
// callers cannot observe failures. No workspace is passed (nullptr).
check_cutlass_status(gemm_op.can_implement(args));
auto status = gemm_op(args, nullptr, (cudaStream_t)stream);
check_cutlass_status(status);
// TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
}
// Bridges the runtime bias pointer to the compile-time WithBias template
// parameter of the SM90 kernel.
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(
    void *out,
    const void *a,
    const void *b,
    const void *a_scale,
    const void *b_scale,
    const void *bias,
    int m,
    int n,
    int k,
    int lda,
    int ldb,
    int ldd,
    void *stream) {
    if (bias == nullptr) {
        cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, false>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        return;
    }
    cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
        out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
// Tile/cluster-shape dispatch for SM90: picks a CTA tile shape, cluster
// shape, and mainloop schedule from (m, n). Thresholds are heuristics.
template <typename ElementOutput>
void sm90_dispatch_shape(
void *out,
const void *a,
const void *b,
const void *a_scale,
const void *b_scale,
const void *bias,
int m,
int n,
int k,
int lda,
int ldb,
int ldd,
void *stream) {
// Small m: tall clusters along N.
if (m <= 32) {
if (n < 8192) {
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _128>,
Shape<_1, _8, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _128, _128>,
Shape<_1, _8, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 64) {
if (n < 8192) {
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _128>,
Shape<_1, _4, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _256>,
Shape<_1, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else if (m <= 128) {
if (n <= 4096) {
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _64, _128>,
Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
} else {
return sm90_dispatch_bias<
ElementOutput,
Shape<_64, _128, _128>,
Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecialized>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
} else {
// Large m: biggest tile with the ping-pong warp-specialized schedule.
return sm90_dispatch_bias<
ElementOutput,
Shape<_128, _128, _128>,
Shape<_2, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
}
}
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#ifdef ENABLE_CUTLASS_API
#include "int8_gemm_kernel.cuh"
#endif
#include "../cuda/per_channel_dequant_int8.cuh"
#include "int8_gemm_nvidia.cuh"
// Global-kernel wrapper (via INFINIOP_CUDA_KERNEL) forwarding to the
// per-element symmetric dequantization device function; this overload adds
// the bias term (see the kernel documentation at the top of the file).
template <typename Tdata>
INFINIOP_CUDA_KERNEL postSym(
Tdata *y, int32_t *y_packed, const Tdata *bias, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, int M, int K, int N) {
postSymKernel<Tdata>(y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}
// Bias-free overload of the symmetric dequantization kernel wrapper.
template <typename Tdata>
INFINIOP_CUDA_KERNEL postSym(
Tdata *y, int32_t *y_packed, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, int M, int K, int N) {
postSymKernel<Tdata>(y, y_packed, x_packed, x_scale, w_packed, w_scale, M, K, N);
}
namespace op::i8gemm::nvidia {
// pImpl-style holder keeping the NVIDIA device-handle internals out of the
// public Descriptor interface.
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
// Releases the Opaque state allocated in Descriptor::create().
Descriptor::~Descriptor() {
delete _opaque;
}
#ifdef ENABLE_NVIDIA_API
// Queries the compute capability of the currently active CUDA device and
// folds it into the conventional integer form (e.g. 8.9 -> 89).
inline int getSMVersion() {
    int dev = -1;
    CHECK_CUDA(cudaGetDevice(&dev));
    int major = 0;
    int minor = 0;
    CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev));
    CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev));
    return major * 10 + minor;
}
#endif
// Builds an i8gemm descriptor: validates the output dtype (F16/BF16),
// derives the GEMM geometry from the tensor descriptors, and sizes a
// workspace large enough for the int32 accumulator matrix.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t bias_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_scale_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_scale_desc) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto dtype = out_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
// NOTE(review): bias_desc / a_scale_desc / b_scale_desc are accepted but not
// validated here -- confirm their shape/dtype checks happen elsewhere.
auto result = I8GemmInfo::create(out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
// Workspace holds the raw int32 GEMM output [dim0 x dim1] before dequantization.
size_t workspace_size = out_desc->dim(0) * out_desc->dim(1) * sizeof(int32_t);
*desc_ptr = new Descriptor(
new Opaque{handle->internal()},
result.take(), workspace_size, dtype,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t Descriptor::launchKernel(const I8GemmInfo &info, Tdata *y, const Tdata *bias, const int8_t *x_packed, const float *x_scale, const int8_t *w_packed, const float *w_scale, void *stream_, void *workspace) const {
cudaStream_t stream = (cudaStream_t)stream_;
int M = (int)info.m;
int K = (int)info.k;
int N = (int)info.n;
char *workspace_ptr = reinterpret_cast<char *>(workspace);
int32_t *y_packed = reinterpret_cast<int32_t *>(workspace_ptr);
const int32_t alpha_I = 1;
const int32_t beta_I = 0;
int lda = K; // w_packed is column-major [K, N]
int ldb = K; // x_packed is row-major [M, K]
int ldc = N; // y_packed is row-major [M, N]
CHECK_STATUS(this->_opaque->internal->useCublas(
stream,
[&](cublasHandle_t handle) {
CHECK_CUBLAS(cublasGemmEx(
handle,
CUBLAS_OP_T, // A = w_packed^T : [N, K]
CUBLAS_OP_N, // B = x_packed^T viewed column-major : [K, M]
N, // m
M, // n
K, // k
&alpha_I,
w_packed, CUDA_R_8I, lda,
x_packed, CUDA_R_8I, ldb,
&beta_I,
y_packed, CUDA_R_32I, ldc,
CUBLAS_COMPUTE_32I,
CUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
}));
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
int num_block_x = (N + BLOCK_SIZE_x - 1) / BLOCK_SIZE_x;
int num_block_y = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, num_block_y, 1);
if (bias == nullptr) {
postSym<Tdata><<<grid_dim, block_dim, 0, stream>>>(y, y_packed, x_packed, x_scale, w_packed, w_scale, M, K, N);
} else {
postSym<Tdata><<<grid_dim, block_dim, 0, stream>>>(y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *bias,
const void *a,
const void *a_scale,
const void *b,
const void *b_scale,
void *stream) const {
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
auto sm_version = getSMVersion();
if (sm_version >= 75 && sm_version < 80) {
CHECK_DTYPE(this->_out_dtype, INFINI_DTYPE_F16);
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else if (sm_version >= 80 && sm_version < 90) {
// sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
if (sm_version == 86 || sm_version == 89) {
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else {
sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
}
} else {
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else {
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
}
}
} else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
// cutlass 3.x
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm90_dispatch_shape<cutlass::bfloat16_t>(
out, a, b, a_scale, b_scale, bias,
_info.m, _info.n, _info.k,
_info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(),
stream);
} else {
sm90_dispatch_shape<cutlass::half_t>(
out, a, b, a_scale, b_scale, bias,
_info.m, _info.n, _info.k,
_info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(),
stream);
}
#else
// // fallback to cutlass 2.x
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else {
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
}
#endif
} else {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
#elif defined ENABLE_QY_API
#define CALCULATE_LINEAR(BLOCK_SIZE, TDATA) \
launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)out, (const TDATA *)bias, (const int8_t *)a, (const float *)a_scale, (const int8_t *)b, (const float *)b_scale, stream, workspace)
#define CALCULATE_LINEAR_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (this->_out_dtype == INFINI_DTYPE_F16) \
return CALCULATE_LINEAR(BLOCK_SIZE, half); \
else if (this->_out_dtype == INFINI_DTYPE_F32) \
return CALCULATE_LINEAR(BLOCK_SIZE, float); \
else if (this->_out_dtype == INFINI_DTYPE_BF16) \
return CALCULATE_LINEAR(BLOCK_SIZE, __nv_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
#endif
return INFINI_STATUS_SUCCESS;
}
} // namespace op::i8gemm::nvidia
#ifndef __INT8_GEMM_NVIDIA_API_H__
#define __INT8_GEMM_NVIDIA_API_H__
#include "../int8_gemm.h"
// Declares op::i8gemm::nvidia::Descriptor via the shared DESCRIPTOR macro.
DESCRIPTOR(nvidia)
#endif // __INT8_GEMM_NVIDIA_API_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/int8_gemm.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/int8_gemm_nvidia.cuh"
#endif
// Routes int8-GEMM descriptor creation to the backend owning
// handle->device. The QY build reuses the nvidia implementation.
__C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
                                                  infiniopI8GemmDescriptor_t *desc_ptr,
                                                  infiniopTensorDescriptor_t out_desc,
                                                  infiniopTensorDescriptor_t bias_desc,
                                                  infiniopTensorDescriptor_t a_desc,
                                                  infiniopTensorDescriptor_t a_scale_desc,
                                                  infiniopTensorDescriptor_t b_desc,
                                                  infiniopTensorDescriptor_t b_scale_desc) {
    switch (handle->device) {
#if defined(ENABLE_NVIDIA_API)
    case INFINI_DEVICE_NVIDIA:
        return op::i8gemm::nvidia::Descriptor::create(
            handle,
            reinterpret_cast<op::i8gemm::nvidia::Descriptor **>(desc_ptr),
            out_desc, bias_desc, a_desc, a_scale_desc, b_desc, b_scale_desc);
#endif
#if defined(ENABLE_QY_API)
    case INFINI_DEVICE_QY:
        return op::i8gemm::nvidia::Descriptor::create(
            handle,
            reinterpret_cast<op::i8gemm::nvidia::Descriptor **>(desc_ptr),
            out_desc, bias_desc, a_desc, a_scale_desc, b_desc, b_scale_desc);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
// Reports the minimum scratch-buffer size the descriptor's backend
// requires for infiniopI8Gemm.
__C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t desc, size_t *size) {
    switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API)
    case INFINI_DEVICE_NVIDIA:
        *size = reinterpret_cast<op::i8gemm::nvidia::Descriptor *>(desc)->minWorkspaceSize();
        return INFINI_STATUS_SUCCESS;
#endif
#if defined(ENABLE_QY_API)
    case INFINI_DEVICE_QY:
        *size = reinterpret_cast<op::i8gemm::nvidia::Descriptor *>(desc)->minWorkspaceSize();
        return INFINI_STATUS_SUCCESS;
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
// Executes the quantized GEMM described by `desc`:
//   out = dequant(a_packed x b_packed) (+ bias if provided)
// `workspace` must be at least the size reported by
// infiniopGetI8GemmWorkspaceSize.
__C infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc,
                                  void *workspace,
                                  size_t workspace_size,
                                  void *out,
                                  const void *bias,
                                  const void *a,
                                  const void *a_scale,
                                  const void *b,
                                  const void *b_scale,
                                  void *stream) {
// Macro name fixed: was misspelled "CACULATE".
#define CALCULATE(CASE, NAMESPACE)                                                     \
    case CASE:                                                                         \
        return reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, out, bias, a, a_scale, b, b_scale, stream);
    switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#if defined(ENABLE_QY_API)
        CALCULATE(INFINI_DEVICE_QY, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CALCULATE
}
// Destroys an int8-GEMM descriptor; the cast recovers the concrete type
// that the matching create() stored behind the opaque handle.
__C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t desc) {
    switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API)
    case INFINI_DEVICE_NVIDIA:
        delete reinterpret_cast<op::i8gemm::nvidia::Descriptor *>(desc);
        return INFINI_STATUS_SUCCESS;
#endif
#if defined(ENABLE_QY_API)
    case INFINI_DEVICE_QY:
        delete reinterpret_cast<op::i8gemm::nvidia::Descriptor *>(desc);
        return INFINI_STATUS_SUCCESS;
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
}
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/sigmoid_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/sigmoid_nvidia.cuh"
#endif
......@@ -34,6 +34,9 @@ __C infiniStatus_t infiniopCreateSigmoidDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -59,6 +62,10 @@ __C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t d
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -91,6 +98,9 @@ __C infiniStatus_t infiniopSigmoid(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -118,6 +128,9 @@ infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/silu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_ALI_API)
#include "nvidia/silu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateSiluDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -77,6 +80,10 @@ __C infiniStatus_t infiniopGetSiluWorkspaceSize(infiniopSiluDescriptor_t desc, s
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -115,6 +122,9 @@ __C infiniStatus_t infiniopSilu(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -148,6 +158,9 @@ infiniopDestroySiluDescriptor(infiniopSiluDescriptor_t desc) {
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __SILU_AND_MUL_INFO_H__
#define __SILU_AND_MUL_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::silu_and_mul {

// Validated shape/dtype metadata for the fused SiLU-and-multiply op.
// Input x has shape [..., 2*h]; output y has shape [..., h]. The halves
// of x's last dimension are combined by the backend kernel (gate/up
// ordering is backend-defined).
class SiluAndMulInfo {
    SiluAndMulInfo() = default;

public:
    infiniDtype_t dtype;   // element type of the output tensor
    size_t batch_size;     // product of all leading (non-hidden) dims
    size_t out_hidden_dim; // h: size of y's innermost dimension

    // Validates that x and y agree on rank and leading dims and that x's
    // last dim is exactly twice y's, then returns the packed info.
    // Returns BAD_PARAM on rank mismatch and BAD_TENSOR_SHAPE on any
    // other shape violation.
    static utils::Result<SiluAndMulInfo> create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) {
        auto dtype = y_desc->dtype();
        auto x_shape = x_desc->shape();
        auto y_shape = y_desc->shape();
        auto ndim = x_desc->ndim();
        if (ndim != y_desc->ndim()) {
            return INFINI_STATUS_BAD_PARAM;
        }
        // Rank-0 guard: without it `ndim - 1` below underflows (unsigned
        // wraparound) and indexes far out of bounds.
        if (ndim == 0) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (x_shape[ndim - 1] != 2 * y_shape[ndim - 1]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t batch = 1;
        for (int i = 0; i < (int)ndim - 1; ++i) {
            if (x_shape[i] != y_shape[i]) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
            batch *= y_shape[i];
        }
        return utils::Result<SiluAndMulInfo>(SiluAndMulInfo{
            dtype,
            batch,
            y_shape[ndim - 1]});
    }

private:
    SiluAndMulInfo(infiniDtype_t dtype, size_t batch, size_t hidden)
        : dtype(dtype), batch_size(batch), out_hidden_dim(hidden) {}
};

} // namespace op::silu_and_mul
#endif // __SILU_AND_MUL_INFO_H__
// Fixed guard-macro typo: "ADN" -> "AND" (both occurrences are local to
// this header, so the rename is interface-safe).
#ifndef __SILU_AND_MUL_MOORE_API_H__
#define __SILU_AND_MUL_MOORE_API_H__
#include "../silu_and_mul.h"
// Declares op::silu_and_mul::moore::Descriptor via the shared macro.
DESCRIPTOR(moore)
#endif // __SILU_AND_MUL_MOORE_API_H__
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "silu_and_mul_moore.h"
#include <musa_bf16.h>
#include <memory>
namespace op::silu_and_mul::moore {
// Backend-private state: shared Moore handle internals (muDNN handle
// access) kept alive for the descriptor's lifetime.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
    delete _opaque;
}
// Creates a SiluAndMul descriptor for the Moore backend.
// Validates the output dtype (F16/F32/BF16), that x and y dtypes match,
// and the x/y shape relationship via SiluAndMulInfo. No workspace is
// needed (muDNN runs the fused kernel directly), hence size 0.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    if (!desc_ptr) {
        return INFINI_STATUS_BAD_PARAM;
    }
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto dtype = y_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    if (x_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    auto result = SiluAndMulInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    auto info = result.take();
    *desc_ptr = new Descriptor(
        new Opaque{handle->internal()},
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Runs muDNN's fused SwiGLU on contiguous [batch, 2*hidden] input,
// producing [batch, hidden] output on the given stream. T selects the
// muDNN tensor element type (half / __mt_bfloat16 / float).
template <typename T>
infiniStatus_t calculate_impl(
    const SiluAndMulInfo &info,
    std::shared_ptr<device::moore::Handle::Internal> &internal,
    void *y,
    const void *x,
    void *stream) {
    return internal->useMudnn(
        (musaStream_t)stream,
        [&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
            ::musa::dnn::Tensor x_t, y_t;
            if constexpr (std::is_same_v<T, half>) {
                x_t.SetType(::musa::dnn::Tensor::Type::HALF);
                y_t.SetType(::musa::dnn::Tensor::Type::HALF);
            } else if constexpr (std::is_same_v<T, __mt_bfloat16>) {
                x_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
                y_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
            } else {
                x_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
                y_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
            }
            x_t.SetAddr(const_cast<void *>(x));
            y_t.SetAddr(y);
            // --- Construct 2D dimension information ---
            // Explicitly distinguish between Batch and Hidden dimensions
            int64_t b = static_cast<int64_t>(info.batch_size);
            int64_t h = static_cast<int64_t>(info.out_hidden_dim);
            // Input x logical shape is [batch, 2 * hidden]
            std::array<int64_t, 2> x_dims = {b, h * 2};
            std::array<int64_t, 2> x_strides = {h * 2, 1};
            // Output y logical shape is [batch, hidden]
            std::array<int64_t, 2> y_dims = {b, h};
            std::array<int64_t, 2> y_strides = {h, 1};
            x_t.SetNdInfo(2, x_dims.data(), x_strides.data());
            y_t.SetNdInfo(2, y_dims.data(), y_strides.data());
            // Invoke muDNN SwiGLU
            // muDNN will split each row (length 2*h) internally,
            // muDNN treats the first h elements of input x as the 'gate'
            // and the following h elements as the 'up' projection.
            ::musa::dnn::SwiGlu swiglu;
            // NOTE(review): the status returned by Run is ignored — a muDNN
            // failure would still be reported as SUCCESS; consider checking it.
            swiglu.Run(mudnn_handle, y_t, x_t);
            return INFINI_STATUS_SUCCESS;
        });
}
// Dispatches to calculate_impl with the concrete element type matching
// the descriptor's dtype. workspace/workspace_size are unused (create
// registers a zero-byte workspace).
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *x,
    void *stream) const {
    infiniDtype_t dtype = _info.dtype;
    switch (dtype) {
    case INFINI_DTYPE_F16:
        return calculate_impl<half>(_info, _opaque->internal, y, x, stream);
    case INFINI_DTYPE_F32:
        return calculate_impl<float>(_info, _opaque->internal, y, x, stream);
    case INFINI_DTYPE_BF16:
        return calculate_impl<__mt_bfloat16>(_info, _opaque->internal, y, x, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::silu_and_mul::moore
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/silu_and_mul.h"
#ifdef ENABLE_MOORE_API
#include "moore/silu_and_mul_moore.h"
#endif
// Routes SiluAndMul descriptor creation to the backend owning
// handle->device.
__C infiniStatus_t infiniopCreateSiluAndMulDescriptor(
    infiniopHandle_t handle,
    infiniopSiluAndMulDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE)                                                      \
    case CASE:                                                                       \
        return op::silu_and_mul::NAMESPACE::Descriptor::create(                      \
            handle,                                                                  \
            reinterpret_cast<op::silu_and_mul::NAMESPACE::Descriptor **>(desc_ptr),  \
            y_desc,                                                                  \
            x_desc);
    switch (handle->device) {
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore);
#endif
    }
// Keep the helper macro local to this function (other dispatchers in this
// project define a CREATE of their own).
#undef CREATE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Reports the workspace size the descriptor's backend requires for
// infiniopSiluAndMul.
__C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE)                                                                         \
    case CASE:                                                                                       \
        *size = reinterpret_cast<const op::silu_and_mul::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
    }
// Keep the helper macro local to this function.
#undef GET
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Executes the fused SiLU-and-multiply operation described by `desc` on
// the backend matching desc->device_type.
__C infiniStatus_t infiniopSiluAndMul(
    infiniopSiluAndMulDescriptor_t desc,
    void *workspace, size_t workspace_size,
    void *y,
    const void *x,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE)                                                            \
    case CASE:                                                                                \
        return reinterpret_cast<const op::silu_and_mul::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, y, x, stream);
    switch (desc->device_type) {
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
    }
// Keep the helper macro local to this function.
#undef CALCULATE
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
// Destroys a SiluAndMul descriptor created by
// infiniopCreateSiluAndMulDescriptor.
__C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE)                                                    \
    case CASE:                                                                      \
        delete reinterpret_cast<const op::silu_and_mul::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;
    switch (desc->device_type) {
#ifdef ENABLE_MOORE_API
        DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
    }
// Keep the helper macro local to this function.
#undef DESTROY
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#ifndef SILU_AND_MUL_H
#define SILU_AND_MUL_H
#include "../../operator.h"
#include "info.h"
// Declares the per-backend Descriptor class for the silu_and_mul op.
// Each backend header expands DESCRIPTOR(<backend>) to get an identical
// class skeleton in namespace op::silu_and_mul::<backend>; the backend
// source then defines ~Descriptor, create and calculate. The Opaque
// struct hides backend-private state; _info carries the validated shapes.
#define DESCRIPTOR(NAMESPACE)                                     \
                                                                  \
    namespace op::silu_and_mul::NAMESPACE {                       \
    class Descriptor final : public InfiniopDescriptor {          \
        struct Opaque;                                            \
        Opaque *_opaque;                                          \
        SiluAndMulInfo _info;                                     \
        size_t _workspace_size;                                   \
                                                                  \
        Descriptor(                                               \
            Opaque *opaque,                                       \
            SiluAndMulInfo info,                                  \
            size_t workspace_size,                                \
            infiniDevice_t device_type,                           \
            int device_id)                                        \
            : InfiniopDescriptor{device_type, device_id},         \
              _opaque(opaque),                                    \
              _info(info),                                        \
              _workspace_size(workspace_size) {}                  \
                                                                  \
    public:                                                       \
        ~Descriptor();                                            \
                                                                  \
        size_t workspaceSize() const { return _workspace_size; }  \
                                                                  \
        static infiniStatus_t create(                             \
            infiniopHandle_t handle,                              \
            Descriptor **desc_ptr,                                \
            infiniopTensorDescriptor_t y_desc,                    \
            infiniopTensorDescriptor_t x_desc);                   \
                                                                  \
        infiniStatus_t calculate(                                 \
            void *workspace, size_t workspace_size,               \
            void *y,                                              \
            const void *x,                                        \
            void *stream) const;                                  \
    };                                                            \
    }
#endif // SILU_AND_MUL_H
......@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/softmax.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/softmax_nvidia.cuh"
#endif
......@@ -33,6 +33,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -57,6 +60,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -86,6 +92,9 @@ __C infiniStatus_t infiniopSoftmax(
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -110,6 +119,9 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
}
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/softplus_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/softplus_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -49,6 +49,10 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -82,6 +86,10 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -123,6 +131,10 @@ __C infiniStatus_t infiniopSoftplus(
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -158,6 +170,10 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/sub_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/sub_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateSubDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -85,6 +88,9 @@ __C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, siz
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -128,6 +134,9 @@ __C infiniStatus_t infiniopSub(
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -164,6 +173,9 @@ infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment