Unverified commit ed04d3e6 authored by qinyiqun, committed by GitHub

Issue/840: NVIDIA Int8 GEMM support (#852)

* can commit

* can exec sm_90a

* can exec < sm_90

* fix format

* fix format

* Add tests, benchmarked against the sglang tests

* fix format 1

* fix format 2

* add compile option to disable cutlass
parent c8a11a6a
#ifndef __INFINIOP_I8GEMM_API_H__
#define __INFINIOP_I8GEMM_API_H__
#include "../operator_descriptor.h"
typedef InfiniopDescriptor *infiniopI8GemmDescriptor_t;
__C __export infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
infiniopI8GemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t bias_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t x_scale_desc,
infiniopTensorDescriptor_t weights_desc,
infiniopTensorDescriptor_t weights_scale_desc);
__C __export infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *bias,
const void *x,
const void *x_scale,
const void *weights,
const void *weights_scale,
void *stream);
__C __export infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t desc);
#endif
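For orientation, the intended call sequence follows the usual InfiniOp pattern: create the descriptor from tensor descriptors, query the workspace size, launch, destroy. A minimal host-side sketch against the declarations above (illustration only; the caller is assumed to have built the handle, tensor descriptors, and device buffers elsewhere):

// Illustration only: a hypothetical caller of the I8Gemm API declared above.
// Assumes the header above is included and all descriptors/buffers are prepared.
infiniStatus_t run_i8gemm(infiniopHandle_t handle,
                          infiniopTensorDescriptor_t out_desc,
                          infiniopTensorDescriptor_t bias_desc,
                          infiniopTensorDescriptor_t x_desc,
                          infiniopTensorDescriptor_t x_scale_desc,
                          infiniopTensorDescriptor_t w_desc,
                          infiniopTensorDescriptor_t w_scale_desc,
                          void *out, const void *bias,
                          const void *x, const void *x_scale,
                          const void *w, const void *w_scale,
                          void *workspace_buf, size_t workspace_cap,
                          void *stream) {
    infiniopI8GemmDescriptor_t desc = nullptr;
    infiniStatus_t st = infiniopCreateI8GemmDescriptor(
        handle, &desc, out_desc, bias_desc, x_desc, x_scale_desc, w_desc, w_scale_desc);
    if (st != INFINI_STATUS_SUCCESS) {
        return st;
    }
    size_t needed = 0;
    st = infiniopGetI8GemmWorkspaceSize(desc, &needed);
    if (st == INFINI_STATUS_SUCCESS && needed <= workspace_cap) {
        // Launch on the given stream with a caller-owned workspace buffer.
        st = infiniopI8Gemm(desc, workspace_buf, needed, out, bias,
                            x, x_scale, w, w_scale, stream);
    }
    infiniopDestroyI8GemmDescriptor(desc);
    return st;
}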
#ifndef __I8GEMM_INFO_H__
#define __I8GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::i8gemm {
struct BlasMatrix {
int ndim;
int batch;
int stride;
int rows;
int cols;
int row_stride;
int col_stride;
static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
BlasMatrix ans;
if (layout->ndim() == 2) {
ans.ndim = 2;
ans.batch = 1;
ans.stride = 0;
ans.rows = layout->dim(0);
ans.cols = layout->dim(1);
ans.row_stride = layout->stride(0);
ans.col_stride = layout->stride(1);
} else if (layout->ndim() == 3) {
ans.ndim = 3;
ans.batch = layout->dim(0);
ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
ans.rows = layout->dim(1);
ans.cols = layout->dim(2);
ans.row_stride = layout->stride(1);
ans.col_stride = layout->stride(2);
} else {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (ans.row_stride != 1 && ans.col_stride != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
return utils::Result<BlasMatrix>(ans);
}
bool match_batch(int _batch) const {
return batch == _batch || batch == 1;
}
void transpose() {
std::swap(rows, cols);
std::swap(row_stride, col_stride);
}
int ld() const {
return row_stride == 1 ? col_stride : row_stride;
}
};
enum class MatrixLayout : char {
COL_MAJOR,
ROW_MAJOR,
};
class I8GemmInfo {
I8GemmInfo() = default;
public:
BlasMatrix a_matrix;
BlasMatrix b_matrix;
BlasMatrix out_matrix;
int m, n, k, batch;
static utils::Result<I8GemmInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
MatrixLayout layout) {
auto a_matrix = BlasMatrix::create(a_desc);
CHECK_RESULT(a_matrix);
auto b_matrix = BlasMatrix::create(b_desc);
CHECK_RESULT(b_matrix);
auto out_matrix = BlasMatrix::create(out_desc);
CHECK_RESULT(out_matrix);
if (out_matrix->rows != a_matrix->rows || out_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto batch = out_matrix->batch;
if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto m = out_matrix->rows;
auto n = out_matrix->cols;
auto k = a_matrix->cols;
return utils::Result<I8GemmInfo>(I8GemmInfo{
a_matrix.take(),
b_matrix.take(),
out_matrix.take(),
m,
n,
k,
batch});
}
};
} // namespace op::i8gemm
#endif // __I8GEMM_INFO_H__
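As a side note on the stride logic above: ld() collapses the (row_stride, col_stride) pair into the single leading dimension a BLAS-style kernel consumes, after create() has already rejected layouts where neither stride is 1. A self-contained sketch with plain integers (illustration only, not part of this commit):

// Illustration of BlasMatrix's stride-to-leading-dimension logic.
struct TinyMatrix {
    int rows, cols;
    int row_stride, col_stride;
    constexpr int ld() const { return row_stride == 1 ? col_stride : row_stride; }
};
// A 128x512 row-major matrix has strides {512, 1}, so ld() is the row stride;
// the same matrix stored column-major has strides {1, 128}, so ld() is the column stride.
static_assert(TinyMatrix{128, 512, 512, 1}.ld() == 512, "row-major");
static_assert(TinyMatrix{128, 512, 1, 128}.ld() == 128, "column-major");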
#ifndef __I8GEMM_H__
#define __I8GEMM_H__
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::i8gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
I8GemmInfo _info; \
infiniDtype_t _out_dtype; \
\
Descriptor(Opaque *opaque, I8GemmInfo info, \
size_t workspace_size, \
infiniDtype_t out_dtype, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, _opaque(opaque),       \
_workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {} \
\
public: \
~Descriptor(); \
\
size_t minWorkspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t bias_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t a_scale_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t b_scale_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *bias, const void *a, \
const void *a_scale, const void *b, \
const void *b_scale, void *stream) const; \
}; \
}
#endif // __I8GEMM_H__
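Each backend instantiates this boilerplate by invoking the macro with its namespace: the `DESCRIPTOR(nvidia)` call further down expands to the full `op::i8gemm::nvidia::Descriptor` class declaration, whose `create` and `calculate` members are then defined in the backend's .cu file.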
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
#pragma once
#include <cutlass/arch/memory.h>
#include <cutlass/numeric_conversion.h>
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <
typename ThreadblockShape_,
int ThreadCount,
typename ScaleTileIterator_,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementwiseFunctor_,
bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using ScaleTileIterator = ScaleTileIterator_;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AlphaScaleElementType = typename ScaleTileIterator::Element;
using ElementCompute = ElementCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
Arguments(
typename ElementwiseFunctor::Params elementwise_,
int64_t batch_stride_alpha_,
int64_t batch_stride_C_,
int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_) {}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Arguments const &args)
: elementwise(args.elementwise),
batch_stride_alpha(args.batch_stride_alpha),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D) {}
};
/// Shared storage
struct SharedStorage {};
private:
Params const &params_;
SharedStorage &shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
bool const with_bias_;
bool const per_token_quant_;
bool const per_channel_quant_;
AlphaScaleElementType *ptr_alpha_row_;
AlphaScaleElementType *ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator beta_;
int column_offset_;
MatrixCoord thread_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(
Params const &params,
SharedStorage &shared_storage,
cutlass::MatrixCoord const &problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
bool with_bias,
bool per_token_quant,
bool per_channel_quant,
AlphaScaleElementType *ptr_alpha_row,
AlphaScaleElementType *ptr_alpha_col,
typename OutputTileIterator::Element *ptr_C,
typename OutputTileIterator::Element *ptr_D,
cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
int column_offset = 0,
cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
with_bias_(with_bias),
per_token_quant_(per_token_quant),
per_channel_quant_(per_channel_quant),
ptr_alpha_row_(ptr_alpha_row),
ptr_alpha_col_(ptr_alpha_col),
iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
extent_real_(problem_size_real) {
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
element_alpha_col_ = *ptr_alpha_col_;
}
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
element_alpha_row_ = *ptr_alpha_row_;
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator slices
CUTLASS_DEVICE
void begin_epilogue() {
if (per_channel_quant_) {
iterator_alpha_col_.load(fragment_alpha_col_);
}
if (with_bias_) {
iterator_C_.load(fragment_C_);
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
// Load alpha_row here in begin_row only when per-token (row) scaling is used.
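// The predicate passed to arch::global_load below guards the read, so rows
// past the problem extent (partial tiles at the bottom edge) are skipped.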
if (per_token_quant_) {
int thread_offset_row = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
}
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const &accum) {
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
ComputeFragment result = source_converter(accum);
if (per_channel_quant_) {
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment *>(&fragment_alpha_col_)[column_idx];
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
} else {
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
}
if (with_bias_) {
NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
OutputVector bias = reinterpret_cast<OutputVector *>(&fragment_C_)[column_idx];
result = bias_accumulator_(result, bias_converter(bias));
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called at the end of a row
CUTLASS_DEVICE
void end_row(int row_idx) {}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {}
private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const &accum, ComputeFragment const &scale_col, AlphaScaleElementType const &scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
result[i] = accum[i] * (scale_col[i] * scale_row);
}
return result;
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const &accum, AlphaScaleElementType const &scale_col, AlphaScaleElementType const &scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
result[i] = accum[i] * (scale_col * scale_row);
}
return result;
}
CUTLASS_DEVICE
ComputeFragment bias_accumulator_(ComputeFragment const &accum, ComputeFragment const &bias) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < OutputVector::kElements; ++i) {
result[i] = accum[i] + bias[i];
}
return result;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
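Taken together, begin_epilogue/begin_row/visit implement a per-row, per-column dequantization: each int32 accumulator is scaled by alpha_row[m] * alpha_col[n] and, when iterator_C_ supplies a bias row, bias[n] is added before conversion to the output type. A scalar reference of the same math in plain C++ (an illustration for intuition, not part of this commit):

// Scalar reference for the per-row/per-column dequantization epilogue above.
#include <cstdint>
void i8gemm_epilogue_ref(const int32_t *acc,      // M x N int32 accumulators
                         const float *alpha_row,  // per-token scales, length M
                         const float *alpha_col,  // per-channel scales, length N
                         const float *bias,       // optional bias row, length N (may be null)
                         float *out, int M, int N) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float v = static_cast<float>(acc[m * N + n]) * (alpha_col[n] * alpha_row[m]);
            if (bias) {
                v += bias[n];  // fragment_C_ carries this bias row in the real kernel
            }
            out[m * N + n] = v;
        }
    }
}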
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>
#include <cutlass/trace.h>
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
that feature at the moment.
*/
template <typename GemmKernel_>
class GemmUniversalBaseCompat {
public:
using GemmKernel = GemmKernel_;
using ThreadblockShape = typename GemmKernel::Mma::Shape;
using ElementA = typename GemmKernel::ElementA;
using LayoutA = typename GemmKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = GemmKernel::kTransformA;
using ElementB = typename GemmKernel::ElementB;
using LayoutB = typename GemmKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = GemmKernel::kTransformB;
using ElementC = typename GemmKernel::ElementC;
using LayoutC = typename GemmKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
using Operator = typename GemmKernel::Operator;
/// Argument structure
using Arguments = typename GemmKernel::Arguments;
protected:
/// Kernel parameters object
typename GemmKernel::Params params_;
protected:
/// Private helper to obtain the grid dimensions with fix-up for split-K
static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
gemm_k_size = args.problem_size.k();
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
}
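    // Worked example for the split-K sizing above (illustration only): with
    // problem_size.k() = 4096, batch_count = 3 split-K slices, and int8 A/B
    // operands (kAlignK = 128 / 8 = 16):
    //   gemm_k_size          = round_up(ceil_div(4096, 3), 16) = round_up(1366, 16) = 1376
    //   grid_tiled_shape.k() = ceil_div(4096, 1376) = 3
    // so each slice owns a 128-bit-aligned chunk of the K dimension.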
public:
/// Constructs the GEMM.
GemmUniversalBaseCompat() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
return Status::kErrorInvalidProblem;
}
return GemmKernel::can_implement(args);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
size_t workspace_bytes = 0;
// Determine grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
// Split-K parallel always requires a temporary workspace
workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
} else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
// Serial split-K only requires a temporary workspace if the number of partitions along the
// GEMM K dimension is greater than one.
workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
}
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
return workspace_bytes;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const &args) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
CUTLASS_TRACE_HOST(
" grid_tiled_shape: " << grid_tiled_shape << "\n"
<< " result = {" << result << "}");
return result;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
if (smem_size <= (48 << 10)) {
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);
if (result == cudaSuccess) {
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
} else {
// Query assuming zero shared memory then compute occupancy limit based on SMEM
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
if (smem_capacity < 0) {
int device_idx = 0;
result = cudaGetDevice(&device_idx);
if (result != cudaSuccess) {
return -1;
}
cudaDeviceProp properties;
result = cudaGetDeviceProperties(&properties, device_idx);
if (result != cudaSuccess) {
return -1;
}
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
}
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
return occupancy;
}
CUTLASS_TRACE_HOST(" returning internal error");
return -1;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST(
"GemmUniversalBaseCompat::initialize() - workspace " << workspace
<< ", stream: " << (stream ? "non-null" : "null"));
size_t workspace_bytes = get_workspace_size(args);
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
if (workspace_bytes) {
if (!workspace) {
CUTLASS_TRACE_HOST(" error: device workspace must not be null");
return Status::kErrorWorkspaceNull;
}
if (args.mode == GemmUniversalMode::kGemm) {
CUTLASS_TRACE_HOST(" clearing device workspace");
cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
}
// Get CUDA grid shape
cutlass::gemm::GemmCoord grid_tiled_shape;
int gemm_k_size = 0;
get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
// Initialize the Params structure
params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int *>(workspace));
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace) {
return Status::kErrorWorkspaceNull;
}
params_.update(args, workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
//
// Configure grid and block dimensions
//
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(GemmKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
//
// Launch kernel
//
CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
// Launch
cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
//
// Query for errors
//
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess) {
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
#pragma once
#include <cutlass/complex.h>
#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/trace.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_ ///! Threadblock swizzling function
>
struct GemmWithEpilogueVisitor {
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueVisitor = typename Epilogue::Visitor;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using TensorRefA = TensorRef<ElementA, LayoutA>;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using TensorRefB = TensorRef<ElementB, LayoutB>;
using ElementCompute = typename EpilogueVisitor::ElementCompute;
using LayoutAlphaCol = cutlass::layout::RowMajor;
using LayoutAlphaRow = cutlass::layout::ColumnMajor;
using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;
using ElementC = typename EpilogueVisitor::ElementOutput;
using LayoutC = typename Epilogue::Layout;
using TensorRefC = TensorRef<ElementC, LayoutC>;
static ComplexTransform const kTransformA = Mma::kTransformA;
static ComplexTransform const kTransformB = Mma::kTransformB;
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
using EpilogueOutputOp =
typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
/// Split-K preserves splits that are 128b aligned
static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
//
// Structures
//
/// Argument structure
struct Arguments {
//
// Data members
//
GemmUniversalMode mode;
GemmCoord problem_size;
int batch_count;
TensorRefA ref_A;
TensorRefB ref_B;
TensorRefAlphaCol ref_alpha_col;
TensorRefAlphaRow ref_alpha_row;
TensorRefC ref_C;
TensorRefC ref_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
int64_t batch_stride_D;
typename EpilogueVisitor::Arguments epilogue_visitor;
//
// Methods
//
Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}
/// constructs an arguments structure
Arguments(
GemmCoord problem_size_,
TensorRefA ref_A_,
TensorRefB ref_B_,
TensorRefAlphaCol ref_alpha_col_,
TensorRefAlphaRow ref_alpha_row_,
TensorRefC ref_C_,
TensorRefC ref_D_,
typename EpilogueVisitor::Arguments epilogue_visitor_)
: mode(GemmUniversalMode::kGemm),
problem_size(problem_size_),
batch_count(1),
ref_A(ref_A_),
ref_B(ref_B_),
ref_alpha_col(ref_alpha_col_),
ref_alpha_row(ref_alpha_row_),
ref_C(ref_C_),
ref_D(ref_D_),
batch_stride_A(0),
batch_stride_B(0),
batch_stride_D(0),
epilogue_visitor(epilogue_visitor_) {}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size;
cutlass::gemm::GemmCoord grid_tiled_shape;
int swizzle_log_tile;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
typename EpilogueVisitor::OutputTileIterator::Params params_C;
typename EpilogueVisitor::OutputTileIterator::Params params_D;
GemmUniversalMode mode;
int batch_count;
int gemm_k_size;
void *ptr_A;
void *ptr_B;
typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_col;
typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_row;
ElementC *ptr_C;
ElementC *ptr_D;
int64_t batch_stride_A;
int64_t batch_stride_B;
typename EpilogueVisitor::Params epilogue_visitor;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: swizzle_log_tile(0),
params_A(0),
params_B(0),
params_alpha_col(0),
params_C(0),
params_D(0),
batch_count(0),
gemm_k_size(0),
mode(cutlass::gemm::GemmUniversalMode::kGemm),
ptr_A(nullptr),
ptr_B(nullptr),
ptr_alpha_col(nullptr),
ptr_alpha_row(nullptr),
ptr_C(nullptr),
ptr_D(nullptr),
batch_stride_A(0),
batch_stride_B(0) {}
Params(Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape_, int gemm_k_size_, int *workspace_)
: problem_size(args.problem_size),
swizzle_log_tile(0),
params_A(args.ref_A.layout()),
params_B(args.ref_B.layout()),
params_alpha_col(args.ref_alpha_col.layout()),
params_alpha_row(args.ref_alpha_col.layout()),
params_C(args.ref_C.layout()),
params_D(args.ref_D.layout()),
mode(args.mode),
batch_count(args.batch_count),
gemm_k_size(args.problem_size.k()),
ptr_A(args.ref_A.data()),
ptr_B(args.ref_B.data()),
ptr_alpha_col(args.ref_alpha_col.data()),
ptr_alpha_row(args.ref_alpha_row.data()),
ptr_C(args.ref_C.data()),
ptr_D(args.ref_D.data()),
batch_stride_A(args.batch_stride_A),
batch_stride_B(args.batch_stride_B),
epilogue_visitor(args.epilogue_visitor) {
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
if (gemm_k_size) {
grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
}
}
swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
struct {
typename Epilogue::SharedStorage epilogue;
typename EpilogueVisitor::SharedStorage visitor;
} epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
GemmWithEpilogueVisitor() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) {
CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;
bool isAMisaligned = false;
bool isBMisaligned = false;
bool isCMisaligned = false;
if (platform::is_same<LayoutA, layout::RowMajor>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
} else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
isAMisaligned = problem_size.m() % kAlignmentA;
} else if (
platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
isAMisaligned = problem_size.k() % kAlignmentA;
}
if (platform::is_same<LayoutB, layout::RowMajor>::value) {
isBMisaligned = problem_size.n() % kAlignmentB;
} else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
} else if (
platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
isBMisaligned = problem_size.k() % kAlignmentB;
}
if (platform::is_same<LayoutC, layout::RowMajor>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
} else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
isCMisaligned = problem_size.m() % kAlignmentC;
} else if (
platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
isCMisaligned = problem_size.n() % kAlignmentC;
}
if (isAMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand");
return Status::kErrorMisalignedOperand;
}
if (isBMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand");
return Status::kErrorMisalignedOperand;
}
if (isCMisaligned) {
CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand");
return Status::kErrorMisalignedOperand;
}
CUTLASS_TRACE_HOST(" returning kSuccess");
return Status::kSuccess;
}
static Status can_implement(Arguments const &args) {
return can_implement(args.problem_size);
}
static size_t get_extra_workspace_size(Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape) {
return 0;
}
#define SPLIT_K_ENABLED 1
/// Executes one GEMM
CUTLASS_DEVICE
void run_kernel_(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
int offset_k = 0;
int problem_size_k = params.problem_size.k();
ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);
#if SPLIT_K_ENABLED
//
// Fetch pointers based on mode.
//
if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
} else if (params.mode == GemmUniversalMode::kBatched) {
ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
} else if (params.mode == GemmUniversalMode::kArray) {
ptr_A = static_cast<ElementA *const *>(params.ptr_A)[threadblock_tile_offset.k()];
ptr_B = static_cast<ElementB *const *>(params.ptr_B)[threadblock_tile_offset.k()];
}
#endif
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
offset_k,
};
cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
//
// Construct the epilogue visitor
//
bool with_bias = true;
if (params.ptr_C == nullptr) {
with_bias = false;
}
EpilogueVisitor epilogue_visitor(
params.epilogue_visitor,
shared_storage.epilogue.visitor,
params.problem_size.mn(),
thread_idx,
warp_idx,
lane_idx,
params.params_alpha_col,
params.params_C,
params.params_D,
with_bias,
true,
true,
params.ptr_alpha_row,
params.ptr_alpha_col,
params.ptr_C,
params.ptr_D,
threadblock_offset,
blockIdx.y * params.problem_size.m());
if (params.mode == GemmUniversalMode::kGemm) {
// Indicate which position in a serial reduction the output operator is currently updating
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
} else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
}
// Construct the epilogue
Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const &params, SharedStorage &shared_storage) {
if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
run_kernel_(params, shared_storage);
} else {
CUTLASS_NOT_IMPLEMENTED();
}
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
run_kernel<ArchTag>(params, shared_storage);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
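The `run_kernel<CompilationArch>` indirection at the end compiles the kernel body only when the template's ArchTag matches the architecture currently being compiled; for every other arch pass, the `if constexpr` branch collapses to `CUTLASS_NOT_IMPLEMENTED()`, so a single translation unit can be built for multiple `-gencode` targets without emitting code an older ISA cannot express.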
#ifdef ENABLE_CUTLASS_API
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "int8_gemm_kernel.cuh"
#include "int8_gemm_nvidia.cuh"
namespace op::i8gemm::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
inline int getSMVersion() {
int device{-1};
CHECK_CUDA(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
CHECK_CUDA(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t bias_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_scale_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_scale_desc) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto dtype = out_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
auto result = I8GemmInfo::create(out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
new Opaque{handle->internal()},
result.take(), 0, dtype,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *out,
const void *bias,
const void *a,
const void *a_scale,
const void *b,
const void *b_scale,
void *stream) const {
auto sm_version = getSMVersion();
if (sm_version >= 75 && sm_version < 80) {
CHECK_DTYPE(this->_out_dtype, INFINI_DTYPE_F16);
sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else if (sm_version >= 80 && sm_version < 90) {
// sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
if (sm_version == 86 || sm_version == 89) {
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else {
sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
}
} else {
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else {
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
}
}
} else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
// cutlass 3.x
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm90_dispatch_shape<cutlass::bfloat16_t>(
out, a, b, a_scale, b_scale, bias,
_info.m, _info.n, _info.k,
_info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(),
stream);
} else {
sm90_dispatch_shape<cutlass::half_t>(
out, a, b, a_scale, b_scale, bias,
_info.m, _info.n, _info.k,
_info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(),
stream);
}
#else
// fallback to cutlass 2.x
if (this->_out_dtype == INFINI_DTYPE_BF16) {
sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
} else {
sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k, _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
}
#endif
} else {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::i8gemm::nvidia
#endif
#ifndef __INT8_GEMM_NVIDIA_API_H__
#define __INT8_GEMM_NVIDIA_API_H__
#include "../int8_gemm.h"
DESCRIPTOR(nvidia)
#endif // __INT8_GEMM_NVIDIA_API_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/int8_gemm.h"
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
#include "nvidia/int8_gemm_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreateI8GemmDescriptor(infiniopHandle_t handle,
infiniopI8GemmDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t bias_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_scale_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_scale_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::i8gemm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, \
bias_desc, \
a_desc, \
a_scale_desc, \
b_desc, \
b_scale_desc);
switch (handle->device) {
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t desc, size_t *size) {
switch (desc->device_type) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->minWorkspaceSize(); \
return INFINI_STATUS_SUCCESS;
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
}
__C infiniStatus_t infiniopI8Gemm(infiniopI8GemmDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *bias,
const void *a,
const void *a_scale,
const void *b,
const void *b_scale,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, out, bias, a, a_scale, b, b_scale, stream);
switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
}
@@ -938,3 +938,42 @@ def tanh_(lib):
lib.infiniopDestroyTanhDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def scaled_mm_int8_(lib):
lib.infiniopCreateI8GemmDescriptor.restype = c_int32
lib.infiniopCreateI8GemmDescriptor.argtypes = [
infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
]
lib.infiniopGetI8GemmWorkspaceSize.restype = c_int32
lib.infiniopGetI8GemmWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t,
POINTER(c_size_t),
]
lib.infiniopI8Gemm.restype = c_int32
lib.infiniopI8Gemm.argtypes = [
infiniopOperatorDescriptor_t,
c_void_p,
c_size_t,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyI8GemmDescriptor.restype = c_int32
lib.infiniopDestroyI8GemmDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@@ -336,7 +336,7 @@ def rearrange_tensor(tensor, new_strides):
torch.float32,
torch.float64,
]:
-        new_tensor.view(-1).index_add_(0, new_positions, tensor.view(-1))
+        new_tensor.view(-1).index_add_(0, new_positions, tensor.contiguous().view(-1))
elif tensor.dtype in [torch.uint16, torch.uint32, torch.uint64]:
new_tensor_int64 = new_tensor.to(dtype=torch.int64)
tensor_int64 = tensor.to(dtype=torch.int64)
......
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# x_shape, w_shape, y_shape
((128, 512), (512, 1024), (128, 1024)),
((256, 1024), (1024, 2048), (256, 2048)),
((1024, 2048), (2048, 1024), (1024, 1024)),
]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE = auto()
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
Inplace.INPLACE,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 3e-1, "rtol": 1e-2},
InfiniDtype.BF16: {"atol": 3e-1, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    o = o * scale_a.view(-1, 1) * scale_b.view(1, -1)
    if bias is not None:
        o = o + bias
    return o.to(out_dtype)
def test(
handle,
device,
x_shape,
w_shape,
y_shape,
inplace=Inplace.OUT_OF_PLACE,
dtype=InfiniDtype.BF16,
sync=None,
):
    print(
        f"Testing I8Gemm on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}"
    )
M, K = x_shape
N = w_shape[1]
x_packed = to_int8(torch.randn((M, K), device="cuda") * 5)
weights = to_int8(torch.randn((N, K), device="cuda").t() * 5)
x_scale = torch.randn((M,), device="cuda", dtype=torch.float32)
weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32)
bias = torch.randn((N,), device="cuda", dtype=torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16) * 10
ans = torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias)
x_packed = TestTensor(
(M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed
)
x_scale = TestTensor(
(M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale
)
weights = TestTensor(
(K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights
)
weights_scale = TestTensor(
(N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale
)
y = TestTensor(y_shape, None, dtype, device)
bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias)
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateI8GemmDescriptor(
handle,
ctypes.byref(descriptor),
y.descriptor,
bias.descriptor,
x_packed.descriptor,
x_scale.descriptor,
weights.descriptor,
weights_scale.descriptor,
)
)
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetI8GemmWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, x_packed.device)
def lib_linear():
check_error(
LIBINFINIOP.infiniopI8Gemm(
descriptor,
workspace.data(),
workspace_size.value,
y.data(),
bias.data(),
x_packed.data(),
x_scale.data(),
weights.data(),
weights_scale.data(),
None,
)
)
lib_linear()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(y.actual_tensor(), ans, atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: torch_scaled_mm(x_packed, weights, x_scale, weights_scale, torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16, bias=bias), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyI8GemmDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
@@ -66,6 +66,16 @@ if has_config("cudnn") then
add_defines("ENABLE_CUDNN_API")
end
option("cutlass")
set_default(false)
set_showmenu(true)
set_description("Whether to compile cutlass for Nvidia GPU")
option_end()
if has_config("cutlass") then
add_defines("ENABLE_CUTLASS_API")
end
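With this option in place, the CUTLASS path is opt-in: something like `xmake f --cutlass=y --cuda_arch=sm_80,sm_90a` (the usual xmake syntax for the options declared here) defines ENABLE_CUTLASS_API, while leaving the option at its default keeps the Int8 GEMM operator compiled out.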
option("cuda_arch")
set_showmenu(true)
set_description("Set CUDA GPU architecture (e.g. sm_90)")
......
@@ -4,11 +4,9 @@ if CUDNN_ROOT ~= nil then
end
local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")
local CUTE_ROOT = os.getenv("CUTE_ROOT") or os.getenv("CUTE_HOME") or os.getenv("CUTE_PATH")
if CUTLASS_ROOT ~= nil then
add_includedirs(CUTLASS_ROOT)
add_includedirs(CUTE_ROOT)
end
target("infiniop-nvidia")
@@ -22,7 +20,6 @@ target("infiniop-nvidia")
if has_config("cudnn") then
add_links("cudnn")
end
add_cugencodes("native")
on_load(function (target)
import("lib.detect.find_tool")
@@ -36,11 +33,6 @@ target("infiniop-nvidia")
target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
target:add("links", "cuda")
local cuda_arch = get_config("cuda_arch")
if cuda_arch ~= nil then
target:add("cu-cxxflags", "-arch=", cuda_arch)
end
end
end)
@@ -65,6 +57,17 @@ target("infiniop-nvidia")
add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations")
local arch_opt = get_config("cuda_arch")
if arch_opt and type(arch_opt) == "string" then
for _, arch in ipairs(arch_opt:split(",")) do
arch = arch:trim()
local compute = arch:gsub("sm_", "compute_")
add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch)
end
else
add_cugencodes("native")
end
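    -- Example (illustration): cuda_arch = "sm_80,sm_90a" expands to
    --   -gencode=arch=compute_80,code=sm_80
    --   -gencode=arch=compute_90a,code=sm_90a
    -- while an unset cuda_arch falls back to the local GPU's arch ("native").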
set_languages("cxx17")
add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")
......