Commit 37c494a7 authored by Zhekai Zhang

Initial release
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include "gemv_awq.h"
#include "../dispatch_utils.h"
#define ENABLE_BF16 1
#include "../utils.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdio.h>
#include "dequantize.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
// Reduce sum within the warp using the tree reduction algorithm.
template <typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t* psum, float (*out_smem)[Num * 4])
{
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i)
{
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
{
#pragma unroll
for (int i = 0; i < Num; ++i)
{
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
};
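// Worked example (illustrative, for the launch below with kInterleave = 4 and
// kThreadsNumPerTile = 2): the partial sums held by lanes {0, 1, 8, 9, 16, 17, 24, 25}
// all belong to the same output row, and the three XOR shuffles with masks 16, 8, 1
// fold them together (likewise for the lane groups starting at lanes 2, 4 and 6).
// Lanes 0/2/4/6 then publish one reduced row each via out_smem[warp][i * 4 + lane / 2].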
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
}
template<typename half_t>
__device__ __forceinline__
packed_as<half_t, 2>::type half2half2(half_t x);
template<>
__device__ __forceinline__
packed_as<half, 2>::type half2half2<half>(half x) {
return __half2half2(x);
}
template<>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
template<typename T>
__device__ __forceinline__
float2 half22float2(T val);
template<>
__device__ __forceinline__
float2 half22float2<half2>(half2 val) {
return __half22float2(val);
}
template<>
__device__ __forceinline__
float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template <typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const half_t* inputs, const uint32_t* weight, const half_t* scales, const half_t* zeros, half_t* outputs,
const int IC, const int OC)
{
using half2_t = typename packed_as<half_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
// assert(MEM_ACCESS_SIZE == 128);
// static constexpr int kShuffleSize = 32;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
alignas(16) half_t local_inputs[kElemsPerThread];
alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
alignas(16) half_t half_weight_buffer[kElemsPerThread];
alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
alignas(16) half_t local_scale[NPerBlock];
alignas(16) half_t local_scaled_zeros[NPerBlock];
accum_t psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f);
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half_t* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half_t* inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
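// Worked example (illustrative, for the launch below: BlockSize = 256, GroupSize = 64,
// so kElemsPerThread = 32 and kThreadsNumPerTile = 2). For threadIdx.x = 37:
// thd_row_offset = (37 / 2) % 4 = 2, act_k_offset = (37 / 8) * 64 + (37 % 2) * 32 = 288,
// group_offset = 288 / 64 = 4; i.e. this thread handles interleaved row offset 2 of the
// block and input channels [288, 320) of the first K slice.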
// Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights; each thread loads 32 int4 values (1 x float4, 128 bits)
*((float4*)(local_qweights)) =
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
half2_t w =
*reinterpret_cast<half2_t*>(
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile
);
w = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const half_t* local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
// load activation, 8 halves (128 bits) / step.
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
accum2_t prod = cuda_cast<accum2_t>(__hmul2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2), half2half2(local_inputs[y])));
*reinterpret_cast<accum2_t*>(psum + batch_idx * NPerBlock + x * 2)
= prod + *reinterpret_cast<accum2_t*>(psum + batch_idx * NPerBlock + x * 2);
// *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2)
// = __hfma2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2),
// half2half2(local_inputs[y]),
// *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2));
}
}
}
inputs_ptr += act_forward_step;
scale_ptr += scale_forward_step;
zeros_ptr += scale_forward_step;
}
warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);
// Num * kInterleave = Batch * NPerBlock * kInterleave outputs are written back per thread block
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
}
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: packed int4 weight tensor, viewed as 32-bit words, of shape [OC, IC // 8];
_scaling_factors: tensor of shape [IC // G, OC];
_zeros: scaled zero-point tensor of shape [IC // G, OC] (same dtype as _in_feats);
m, n, k: B, OC and IC, respectively;
group_size: quantization group size G (must be 64 here);
Returns:
out_feats: tensor of shape [B, OC];
*/
Tensor gemv_awq(
Tensor _in_feats,
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size)
{
using half_t = __nv_bfloat16;
// using half_t = half;
assert(isTypeMatch<half_t>(_in_feats.dtype()));
auto output_shape = _in_feats.shape.dataExtent;
output_shape.back() = n;
auto in_feats = reinterpret_cast<half_t*>(_in_feats.data_ptr<half_t>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half_t*>(_zeros.data_ptr<half_t>());
auto scaling_factors = reinterpret_cast<half_t*>(_scaling_factors.data_ptr<half_t>());
Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
half_t * out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
constexpr int GROUP_SIZE = 64;
assert(m > 0 && m < 8);
assert(group_size == GROUP_SIZE);
dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
if constexpr (M == 0) {
assert(false);
return;
}
if constexpr (M > 0) {
// The kernel accumulates into dynamically sized shared memory
// ((BLOCK_SIZE / WARP_SIZE) * Num * kInterleave floats), so pass the byte count at launch.
constexpr size_t smem_size = (BLOCK_SIZE / WARP_SIZE) * N_PER_BLOCK * M * K_INTERLEAVE * sizeof(float);
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE><<<num_blocks, num_threads, smem_size>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
checkCUDA(cudaGetLastError());
}
});
return _out_feats;
}
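// Illustrative host-side usage (a sketch; how the tensors are produced depends on the
// surrounding Tensor API and the AWQ repacking step, which are not shown here):
//   Tensor in  = ...;  // BF16 activations, shape [m, k]
//   Tensor qw  = ...;  // packed/interleaved int4 weights viewed as 32-bit words, n * k / 8 words
//   Tensor sc  = ...;  // BF16 scales,       shape [k / 64, n]
//   Tensor zp  = ...;  // BF16 scaled zeros, shape [k / 64, n]
//   Tensor out = gemv_awq(in, qw, sc, zp, /*m=*/m, /*n=*/n, /*k=*/k, /*group_size=*/64);
// Constraints enforced above: 0 < m < 8, group_size == 64, and the activation dtype must
// match half_t (currently __nv_bfloat16).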
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor gemv_awq(
Tensor _in_feats,
Tensor _kernel,
Tensor _scaling_factors,
Tensor _zeros,
int m,
int n,
int k,
int group_size);
#pragma once
#include "common.h"
#include "Tensor.h"
#include <cuda_fp16.h>
template<typename F>
inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__nv_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
case Tensor::FP32:
return func.template operator()<float>();
default:
assert(false);
}
}
template<typename F>
inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__nv_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
case Tensor::FP32:
return func.template operator()<float>();
case Tensor::INT8:
return func.template operator()<int8_t>();
case Tensor::INT32:
return func.template operator()<int32_t>();
case Tensor::INT64:
return func.template operator()<int64_t>();
default:
throw std::runtime_error("Unsupported scalar type");
}
}
#pragma nv_diagnostic push
// warning #445-D: template parameter "scalar_t" is not used in declaring the parameter types of function template "lambda []()->auto::operator auto (*)()"
#pragma nv_diag_suppress 445
template<typename T>
inline bool isTypeMatch(Tensor::ScalarType scalarType) {
return dispatch(scalarType, []<typename scalar_t>() {
return std::is_same_v<scalar_t, T>;
});
}
#pragma nv_diagnostic pop
template<typename F, int ...N>
inline auto dispatchVal(int val, std::integer_sequence<int, N...>, F &&func) {
auto call = [&]<int i>() {
if (val == i) {
func.template operator()<i>();
}
};
(call.template operator()<N>(), ...);
}
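// Illustrative usage (sketch), mirroring the call site in gemv_awq.cu: map a runtime
// value in [0, 8) to a compile-time constant M:
//   dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
//       // M is a constexpr int equal to m here
//   });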
template<typename F>
inline auto dispatchBool(bool val, F &&func) {
if (val) {
func.template operator()<true>();
} else {
func.template operator()<false>();
}
}
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) dispatchFloat(TYPE, [&]<typename scalar_t>() { __VA_ARGS__(); });
#include "common.h"
#include "Tensor.h"
#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/conv/device/direct_convolution.h"
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
// depthwise_Conv2d operation cutlass_sm80_tensorop_f16_s16x8x16fprop_analytic_f16_256x128_64x3_nhwc_align8
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
using FilterShape = cutlass::MatrixShape<3, 3>;
using ThreadblockShape =
cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80, // TODO
ThreadblockShape,
ThreadBlockOutputShape,
FilterShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, float, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>,
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
1,
ThreadBlockOutputShape::kN,
ThreadBlockOutputShape::kH,
ThreadBlockOutputShape::kW>,
4,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
cutlass::conv::StrideSupport::kFixed,
cutlass::MatrixShape<1, 1>,
cutlass::MatrixShape<1, 1>>::Kernel;
using DeviceKernel =
typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
namespace {
using TensorRefA = typename UnderlyingKernel::TensorRefA;
using TensorRefB = typename UnderlyingKernel::TensorRefB;
using TensorRefC = typename UnderlyingKernel::TensorRefC;
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
}
template <typename TensorRef, typename Element>
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element *ptr) {
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
TensorRef tensor_ref(ptr, layout);
return tensor_ref;
}
static cutlass::Status depthwise_conv2d_kernel_run(cutlass::conv::Conv2dProblemSize *problem_size,
UnderlyingKernel::ElementA *A, UnderlyingKernel::ElementB *B,
UnderlyingKernel::ElementC *C, UnderlyingKernel::ElementC *D,
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
cudaStream_t stream, int device_id = 0)
{
// create the tensor references
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
cutlass::conv::Operator::kFprop, *problem_size);
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
cutlass::conv::Operator::kFprop, *problem_size);
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
cutlass::conv::Operator::kFprop, *problem_size);
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
cutlass::conv::SplitKMode mode;
if (split_k_mode == "serial") {
mode = cutlass::conv::SplitKMode::kSerial;
} else if (split_k_mode == "parallel") {
mode = cutlass::conv::SplitKMode::kParallel;
} else {
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
}
typename DeviceKernel::Arguments arguments{
*problem_size,
tensor_ref_A,
tensor_ref_B,
tensor_ref_C,
tensor_ref_D,
{alpha, beta},
tensor_ref_B,
// mode
};
DeviceKernel implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
BufferCUDA workspace(workspace_size);
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return status;
}
status = implicit_gemm_op.initialize(arguments, workspace.getPtr(), stream);
if (status != cutlass::Status::kSuccess) {
return status;
}
//
// Launch initialized CUTLASS kernel
//
status = implicit_gemm_op(stream);
return status;
}
Tensor depthwise_conv2d_kernel(Tensor A, Tensor B) {
int N, H, W, C_, K, C__, R, S, P, Q;
N = A.size(0);
H = A.size(1);
W = A.size(2);
C_ = A.size(3);
// std::cout << A.dtype() << std::endl;
// std::cout << "A: " << N << ", " << H << ", " << W << ", " << C_ << std::endl;
// for (int h = 0; h < H; ++h)
// {
// for (int w = 0; w < W; ++w)
// std::cout << A[0][h][w][0].item() << " ";
// std::cout << std::endl;
// }
K = B.size(0);
R = B.size(1);
S = B.size(2);
C__ = B.size(3);
// std::cout << B.dtype() << std::endl;
// std::cout << "B: " << K << ", " << R << ", " << S << ", " << C__ << std::endl;
// for (int h = 0; h < R; ++h)
// {
// for (int w = 0; w < S; ++w)
// std::cout << B[0][h][w][0].item() << " ";
// std::cout << std::endl;
// }
cutlass::conv::Conv2dProblemSize problem_size(
cutlass::Tensor4DCoord(N, H, W, C_),
cutlass::Tensor4DCoord(K, R, S, C__),
cutlass::Tensor4DCoord(1, 1, 1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::MatrixCoord(1, 1),
cutlass::conv::Mode::kCrossCorrelation,
1,
C_ // groups
);
P = problem_size.P;
Q = problem_size.Q;
// printf("P=%d Q=%d\n", P, Q);
typename UnderlyingKernel::ElementC *ptrC = nullptr;
Tensor D = Tensor::allocate({N, P, Q, K}, A.dtype(), A.device());
auto stream = getCurrentCUDAStream();
cutlass::Status status = depthwise_conv2d_kernel_run(
&problem_size,
reinterpret_cast<typename UnderlyingKernel::ElementA *>(A.data_ptr()),
reinterpret_cast<typename UnderlyingKernel::ElementB *>(B.data_ptr()),
ptrC,
reinterpret_cast<typename UnderlyingKernel::ElementC *>(D.data_ptr()),
1, 0,
"serial", stream, B.device().idx);
assert(status == cutlass::Status::kSuccess);
return D;
}
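// Worked example (illustrative): for NHWC input A of shape [1, 64, 64, 32] and a 3x3
// depthwise filter B of shape [32, 3, 3, 1], the problem above uses symmetric padding of 1,
// stride (1, 1), dilation (1, 1) and groups = 32, so P = (64 + 2 - 3) / 1 + 1 = 64 and
// Q = 64; the output D has shape [1, 64, 64, 32].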
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
../../third_party/Block-Sparse-Attention/csrc/block_sparse_attn
#include "gemm_batched.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm_batched.h>
#include <cutlass/numeric_types.h>
using spdlog::fmt_lib::format;
Tensor gemm_batched_fp16(
Tensor a, // FP16 row-major [(... batch ...), M, K]
Tensor b, // FP16 col-major [(... batch ...), N, K]
Tensor out // FP32 row-major [(... batch ...), M, N]
)
{
const int M = a.shape[-2];
const int K = a.shape[-1];
const int N = b.shape[-2];
const int batch = a.numel() / (M * K);
using ElementInput = cutlass::half_t;
using ElementOutput = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutO = cutlass::layout::RowMajor;
using Gemm = cutlass::gemm::device::GemmBatched<
ElementInput, LayoutA,
ElementInput, LayoutB,
ElementOutput, LayoutO,
ElementOutput,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementOutput, ElementOutput>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
2>;
auto sizeA = cutlass::MatrixCoord(M, K);
auto sizeB = cutlass::MatrixCoord(K, N);
auto sizeO = cutlass::MatrixCoord(M, N);
if (!out.valid()) {
auto outShape = a.shape;
outShape[-1] = N;
out = Tensor::empty(outShape, Tensor::FP32, a.device());
}
assert(K == b.shape[-1]);
assert(M == out.shape[-2]);
assert(N == out.shape[-1]);
assert(a.dtype() == Tensor::FP16);
assert(a.dtype() == b.dtype());
assert(out.dtype() == Tensor::FP32);
cutlass::gemm::GemmCoord problemSize(M, N, K);
cutlass::TensorRef<ElementInput, LayoutA> refA(
a.data_ptr<ElementInput>(), LayoutA(a.stride(-2)));
cutlass::TensorRef<ElementInput, LayoutB> refB(
b.data_ptr<ElementInput>(), LayoutB(b.stride(-2)));
cutlass::TensorRef<ElementOutput, LayoutO> refO(
out.data_ptr<ElementOutput>(), LayoutO(out.stride(-2)));
typename Gemm::Arguments arguments{
problemSize,
refA,
(int)a.stride(-3),
refB,
(int)b.stride(-3),
refO,
(int)out.stride(-3),
refO,
(int)out.stride(-3),
{ ElementOutput(1), ElementOutput(0) },
batch
};
Gemm op;
BufferCUDA workspace(Gemm::get_workspace_size(arguments));
cutlass::Status status = op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
status = op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
return out;
}
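// Worked example (illustrative, assuming contiguous tensors): for a of shape
// [4, 6, 64, 128] (FP16) and b of shape [4, 6, 256, 128] (FP16), M = 64, K = 128,
// N = 256 and batch = 4 * 6 = 24; out is allocated as [4, 6, 64, 256] in FP32, and the
// per-matrix batch strides passed to CUTLASS are a.stride(-3) = 64 * 128,
// b.stride(-3) = 256 * 128 and out.stride(-3) = 64 * 256.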
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor gemm_batched_fp16(
Tensor a, // FP16 row-major [(... batch ...), M, K]
Tensor b, // FP16 col-major [(... batch ...), N, K]
Tensor out // FP32 row-major [(... batch ...), M, N]
);
#include "gemm_f16.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/bfloat16.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
using spdlog::fmt_lib::format;
Tensor gemm_f16(Tensor input, // FP16/BF16 (element types below are currently hard-coded to bfloat16_t)
Tensor weight, // same dtype as input
Tensor out, // same dtype as input; may be invalid, in which case it is allocated here
float alpha,
float beta
) {
auto N = weight.size(0);
auto K = input.size(-1);
auto M = input.numel() / K;
assert(weight.size(1) == K);
spdlog::debug("gemm_f16: M={} K={} N={}", M, K, N);
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = cutlass::bfloat16_t;
using ElementInputA = cutlass::bfloat16_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::bfloat16_t; // <- data type of elements in input matrix B
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// #if CUDA_ARCH >= 800
using Gemm = cutlass::gemm::device::Gemm<
ElementInputA, cutlass::layout::RowMajor, ElementInputB, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = input.shape;
out_shape[-1] = N;
out = Tensor::empty(out_shape, input.scalar_type(), input.device());
}
// FIXME: check contiguous of input if dims >= 3
assert(input.stride(-1) == 1);
// assert(input.is_contiguous());
assert(weight.is_contiguous());
assert(out.dtype() == input.scalar_type());
assert(out.shape[-1] == N);
assert(out.numel() / out.shape[-1] == M);
assert(out.stride(-1) == 1);
// FIXME: check contiguous of output if dims >= 3
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{ElementOutput(alpha), ElementOutput(beta)},
1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
BufferCUDA workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
return out;
}
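// In effect (for reference): weight is stored as [N, K] row-major and interpreted as a
// column-major B of shape [K, N], so the kernel computes
//   out[m][n] = alpha * sum_k input[m][k] * weight[n][k] + beta * out[m][n],
// i.e. a linear layer out = alpha * input @ weight^T + beta * out.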
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor gemm_f16(
Tensor input, // FP16
Tensor weight, // FP16
Tensor out, // FP16
float alpha,
float beta
);
#pragma once
#include <cstdint>
#include "common.h"
#include "utils.cuh"
static constexpr int clamp(int val, int min, int max) {
if (val < min)
return min;
if (val > max)
return max;
return val;
}
template<bool shmem = false, typename T>
__device__ __forceinline__
static T load(const T *addr) {
if constexpr (shmem) {
if constexpr (sizeof(T) == 8) {
uint2 data;
asm volatile ("ld.shared.v2.b32 {%0, %1}, [%2];" : "=r"(data.x), "=r"(data.y) : "l"(__cvta_generic_to_shared(addr)));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data;
asm volatile ("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(__cvta_generic_to_shared(addr)));
return *reinterpret_cast<T *>(&data);
}
return *addr;
}
if constexpr (sizeof(T) == 8) {
uint2 data = __ldg(reinterpret_cast<const uint2 *>(addr));
return *reinterpret_cast<T *>(&data);
}
if constexpr (sizeof(T) == 16) {
uint4 data = __ldg(reinterpret_cast<const uint4 *>(addr));
return *reinterpret_cast<T *>(&data);
}
return *addr;
}
template<bool shmem = false, typename T>
__device__ __forceinline__
static void store(T *addr, T val) {
if constexpr (shmem) {
if constexpr (sizeof(T) == 8) {
uint2 data = *reinterpret_cast<uint2 *>(&val);
asm volatile ("st.shared.v2.b32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y));
return;
}
if constexpr (sizeof(T) == 16) {
uint4 data = *reinterpret_cast<uint4 *>(&val);
asm volatile ("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(addr)), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
return;
}
*addr = val;
return;
}
if constexpr (sizeof(T) == 4) {
__stcg(reinterpret_cast<uint *>(addr), *reinterpret_cast<uint *>(&val));
return;
}
if constexpr (sizeof(T) == 8) {
__stcg(reinterpret_cast<uint2 *>(addr), *reinterpret_cast<uint2 *>(&val));
return;
}
if constexpr (sizeof(T) == 16) {
__stcg(reinterpret_cast<uint4 *>(addr), *reinterpret_cast<uint4 *>(&val));
return;
}
*addr = val;
}
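// Illustrative usage (sketch): both helpers select a vectorized path from sizeof(T):
//   uint4 v = load(gmem_ptr);          // 16-byte global load via __ldg
//   store(gmem_dst, v);                // 16-byte global store via __stcg (streaming)
//   uint2 s = load<true>(shmem_ptr);   // 8-byte shared-memory load via ld.shared.v2.b32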
template<typename T>
__device__ __forceinline__
float2 half22float2(T val);
template<>
__device__ __forceinline__
float2 half22float2<half2>(half2 val) {
return __half22float2(val);
}
template<>
__device__ __forceinline__
float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
template<typename T>
__device__ __forceinline__
T float22half2(float2 val);
template<>
__device__ __forceinline__
half2 float22half2<half2>(float2 val) {
return __float22half2_rn(val);
}
template<>
__device__ __forceinline__
__nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
return __float22bfloat162_rn(val);
}
template<typename T>
__device__ __forceinline__
void unused_var(T &val, bool alwaysfalse) {
volatile T *ptr = nullptr;
if (alwaysfalse) {
*ptr = val;
}
}
__device__ __forceinline__
static void ldmatrix(const void *ptr, uint4 &out) {
asm volatile(
"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
: "l"(__cvta_generic_to_shared(ptr))
);
}
// x in low bit, y in high bit
template<int bitwidth, bool use_unsigned>
__device__ __forceinline__
uint32_t quantize_float2(float2 value) = delete;
template<>
__device__ __forceinline__
uint32_t quantize_float2<4, false>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile ("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
return result;
}
template<>
__device__ __forceinline__
uint32_t quantize_float2<4, true>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile ("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
return result;
}
template<>
__device__ __forceinline__
uint32_t quantize_float2<8, false>(float2 value) {
int v1, v2;
uint32_t result;
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
asm volatile ("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
asm volatile ("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
return result;
}
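// Worked example (illustrative, under the PTX cvt.pack semantics assumed here):
// quantize_float2<4, false>(make_float2(1.0f, -2.0f)) rounds to v1 = 1, v2 = -2 and packs
// x into bits [3:0] and y into bits [7:4], giving 0x000000E1; callers then shift the byte
// into position when assembling a full 32-bit word of eight int4 values.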
__device__ __forceinline__
static float cuda_tanhf(float x) {
float result;
asm ("tanh.approx.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_frcp(float x) {
float result;
asm ("rcp.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_frsqrt(float x) {
float result;
asm ("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_sin(float x) {
float result;
asm ("sin.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
__device__ __forceinline__
static float cuda_cos(float x) {
float result;
asm ("cos.approx.ftz.f32 %0, %1;" : "=f"(result) : "f"(x));
return result;
}
// https://forums.developer.nvidia.com/t/hardware-accelerated-computation-of-the-sigmoid-logistic-function/266206
__forceinline__ __device__
static float cuda_sigmoidf (float a)
{
#if USE_TANH
return fmaf (0.5, __tanhf (0.5f * a), 0.5f);
#else // USE_TANH
const float L2E = 1.442695041f; // log2(exp(1))
float t, d, e, r;
t = -L2E * a;
asm ("ex2.approx.ftz.f32 %0,%1;\n\t" : "=f"(e) : "f"(t));
d = e + 1.0f;
asm ("rcp.approx.ftz.f32 %0,%1;\n\t" : "=f"(r) : "f"(d));
return r;
#endif // USE_TANH
}
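// Derivation (for reference): sigmoid(a) = 1 / (1 + exp(-a)) = 1 / (1 + 2^(-a * log2(e))),
// so the code computes t = -1.442695 * a, e = 2^t via ex2.approx, and r = 1 / (1 + e)
// via rcp.approx.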
template<typename T>
__device__ __forceinline__
static T gelu_half2(T x) {
float2 xf = half22float2<T>(x);
float2 x3f = xf * xf * xf;
float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));
return float22half2<T>(xf * make_float2(t1, t2));
}
template<typename T>
__device__ __forceinline__
static T gelu_half(T x) {
float xf = float(x);
float x3f = xf * xf * xf;
float t = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf + (0.044715f * x3f)));
return (T)(xf * t);
}
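// Both gelu variants use the tanh approximation
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.79788456.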
template <typename T>
__device__ __forceinline__
static T silu(const T &x) {
// x * sigmoid(x)
return (T)((float)x * cuda_sigmoidf((float)x));
// return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
}
__device__ __forceinline__
static void reduce_add(float *addr, float val) {
asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
}
template<int cnt, typename F>
__device__ __forceinline__
static void unrolled_loop(F &&lambda) {
auto call = [&]<int ...Is>(std::integer_sequence<int, Is...>) {
(lambda.template operator()<Is>(), ...);
};
call(std::make_integer_sequence<int, cnt>());
}
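// Illustrative usage (sketch): run a loop body with the index as a constexpr template
// argument, fully unrolled at compile time:
//   unrolled_loop<4>([&]<int I>() { acc[I] += x[I]; });
// expands to the four calls I = 0, 1, 2, 3.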
#include "common.h"
#include "Tensor.h"
#include "utils.cuh"
#include "gemm_utils.cuh"
#include "dispatch_utils.h"
#pragma nv_diag_suppress 177
#define ALWAYSINLINE __attribute__((always_inline))
// #define ENABLE_NAN_CHECK 1
#if ENABLE_NAN_CHECK
#define STRINGIZE(x) STRINGIZE2(x)
#define STRINGIZE2(x) #x
#define CHECK_NAN(data, name) checkNan(data, name " at " STRINGIZE(__LINE__))
#else
#define CHECK_NAN(data, name)
#endif
class GEMMConfig_W4A4 {
public:
// BE CAREFUL: weights need to be repacked when the tiling size changes
static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8;
static constexpr int INSN_M = 16;
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 64;
#if 0
using half_t = half;
using half2_t = half2;
#else
using half_t = __nv_bfloat16;
using half2_t = __nv_bfloat162;
#endif
};
class GEMMConfig_W8A8 {
public:
static constexpr int BLOCK_M = 256;
static constexpr int BLOCK_N = 128;
static constexpr int WARP_SIZE = 32;
static constexpr int NUM_WARPS = 8;
static constexpr int INSN_M = 16;
static constexpr int INSN_N = 16;
static constexpr int INSN_K = 32;
using half_t = half;
using half2_t = half2;
};
template<class Config>
class GEMMBase : public Config {
public:
using Config::BLOCK_M;
using Config::BLOCK_N;
using Config::WARP_SIZE;
using Config::NUM_WARPS;
using Config::INSN_M;
using Config::INSN_N;
using Config::INSN_K;
using typename Config::half_t;
using typename Config::half2_t;
static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
static constexpr int WARP_N = BLOCK_N;
static constexpr int WARP_K = INSN_K;
static constexpr int WARP_M_TILES = WARP_M / INSN_M;
static constexpr int WARP_N_TILES = WARP_N / INSN_N;
static constexpr int WARP_K_TILES = WARP_K / INSN_K;
/**
* refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
*
* wscales store order: (pack = 4)
* 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
* ... ...
* 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* wscales store order: (pack = 8)
* 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
*
* max pack size is 8 since the max load size is 16 bytes per lane
* min pack size is 2 since the shuffle granularity is 32 bits (2 x half)
* */
static constexpr int WSCALES_PACK_SIZE = clamp(WARP_N / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int WSCALES_NUM_PACKS = ceilDiv(WARP_N, (WSCALES_PACK_SIZE * WARP_SIZE));
static constexpr int WSCALES_VALID_LANES = std::min(WARP_SIZE, WARP_N / WSCALES_PACK_SIZE);
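// e.g. for the configs above (WARP_N = BLOCK_N = 128): WSCALES_PACK_SIZE =
// clamp(128 / 32, 2, 8) = 4, WSCALES_NUM_PACKS = ceilDiv(128, 4 * 32) = 1 and
// WSCALES_VALID_LANES = min(32, 128 / 4) = 32.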
/**
* ascales store order: (pack = 2)
* 0 8 <-- load by lane 0, broadcast to lane {0, 1, 2, 3} (4x)
* 1 9 <-- load by lane 1, broadcast to lane {4, 5, 6, 7} (4x)
* 2 10
* ...
* 6 14
* 7 15 <-- load by lane 7, broadcast to lane {28, 29, 30, 31} (4x)
* ... ...
* 48 56 <-- load by lane 24, broadcast to lane {0, 1, 2, 3} (4x)
* 49 57
* ...
* 54 62
* 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x)
*
* {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
*/
static constexpr int ASCALES_PACK_SIZE = clamp(WARP_M / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
static constexpr int ASCALES_NUM_PACKS = ceilDiv(WARP_M, (ASCALES_PACK_SIZE * WARP_SIZE));
static constexpr int ASCALES_VALID_LANES = std::min(WARP_SIZE, WARP_M / ASCALES_PACK_SIZE);
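// e.g. for the configs above (WARP_M = 256 / 8 = 32): ASCALES_PACK_SIZE =
// clamp(32 / 32, 2, 8) = 2, ASCALES_NUM_PACKS = ceilDiv(32, 2 * 32) = 1 and
// ASCALES_VALID_LANES = min(32, 32 / 2) = 16.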
using packed_act_t = uint4;
using packed_wgt_t = uint4;
struct alignas(32) packed_psum_t {
int data[8];
};
struct alignas(16) packed_fpsum_t {
half2_t data[4];
};
struct alignas(8) packed_gated_fpsum_t {
half_t data[4];
};
// 16 * 16 matrix
struct alignas(32) packed_f32psum_t {
float data[8];
static constexpr packed_f32psum_t zeros() {
packed_f32psum_t result;
for (int i = 0; i < 8; i++) {
result.data[i] = 0;
}
return result;
};
};
struct packed_wscale_t {
half2_t data[WSCALES_PACK_SIZE / 2];
};
struct packed_ascale_t {
half2_t data[ASCALES_PACK_SIZE / 2];
};
using act_warp = std::array<packed_act_t, WARP_M_TILES>;
using wgt_warp = std::array<packed_wgt_t, WARP_N_TILES>;
using ascale_warp = std::array<packed_ascale_t, ASCALES_NUM_PACKS>;
using wscale_warp = std::array<packed_wscale_t, WSCALES_NUM_PACKS>;
using fpsum_warp = std::array<packed_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
using f32psum_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
using gated_fpsum_warp = std::array<packed_gated_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
struct BlockInfo {
int bm;
int bn;
int numBlocksM;
int numBlocksN;
};
__device__ __forceinline__
static packed_f32psum_t mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) {
static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>);
if constexpr (std::is_same_v<half_t, half>) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
:
"r"(*reinterpret_cast<uint *>(&a.data[0])),
"r"(*reinterpret_cast<uint *>(&a.data[1])),
"r"(*reinterpret_cast<uint *>(&a.data[2])),
"r"(*reinterpret_cast<uint *>(&a.data[3])),
"r"(*reinterpret_cast<uint *>(&b.data[0])),
"r"(*reinterpret_cast<uint *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
:
"r"(*reinterpret_cast<uint *>(&a.data[0])),
"r"(*reinterpret_cast<uint *>(&a.data[1])),
"r"(*reinterpret_cast<uint *>(&a.data[2])),
"r"(*reinterpret_cast<uint *>(&a.data[3])),
"r"(*reinterpret_cast<uint *>(&b.data[2])),
"r"(*reinterpret_cast<uint *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
);
}
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
:
"r"(*reinterpret_cast<uint *>(&a.data[0])),
"r"(*reinterpret_cast<uint *>(&a.data[1])),
"r"(*reinterpret_cast<uint *>(&a.data[2])),
"r"(*reinterpret_cast<uint *>(&a.data[3])),
"r"(*reinterpret_cast<uint *>(&b.data[0])),
"r"(*reinterpret_cast<uint *>(&b.data[1])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
:
"r"(*reinterpret_cast<uint *>(&a.data[0])),
"r"(*reinterpret_cast<uint *>(&a.data[1])),
"r"(*reinterpret_cast<uint *>(&a.data[2])),
"r"(*reinterpret_cast<uint *>(&a.data[3])),
"r"(*reinterpret_cast<uint *>(&b.data[2])),
"r"(*reinterpret_cast<uint *>(&b.data[3])),
// "r"(0), "r"(0), "r"(0), "r"(0)
"f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
);
}
return psum;
}
__device__ __forceinline__
static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
packed_fpsum_t results;
for (int i = 0; i < 4; i++) {
results.data[i] = float22half2<half2_t>(float2(input.data[i * 2], input.data[i * 2 + 1]));
}
return results;
}
__device__ __forceinline__
static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
packed_f32psum_t results;
for (int i = 0; i < 4; i++) {
float2 tmp = half22float2(input.data[i]);
results.data[i * 2] = tmp.x;
results.data[i * 2 + 1] = tmp.y;
}
return results;
}
__device__ __forceinline__
static fpsum_warp packed_fp32_to_fp16(f32psum_warp input) {
fpsum_warp results;
#pragma unroll
for (int i = 0; i < results.size(); i++) {
results[i] = packed_fp32_to_fp16(input[i]);
}
return results;
}
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
__device__ __forceinline__
static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
if (pred) {
// out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
}
}
}
// weight: column major: [N / BLOCK_N, 1, K / WARP_K, WARP_N_TILES, WARP_SIZE] of packed_wgt_t
__device__ __forceinline__
static void load_wgt(const packed_wgt_t *wgt, int k, int K, wgt_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
// const packed_wgt_t *ptr = &wgt[(0 * K / WARP_K + k) * WARP_SIZE + laneId];
const packed_wgt_t *ptr = &wgt[(0 + k * WARP_N_TILES) * WARP_SIZE + laneId];
// int offset = K / WARP_K * WARP_SIZE;
#pragma unroll
for (int i = 0; i < WARP_N_TILES; i++) {
if (pred) {
// out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
// out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
out[i] = load(&ptr[i * WARP_SIZE]);
// ptr += offset;
}
}
}
// ascales: row major [M / BLOCK_M, K / group size, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
__device__ __forceinline__
static void load_ascale(const packed_ascale_t *ascales, int group, int M, ascale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
if (pred && laneId < ASCALES_VALID_LANES) {
// out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
}
}
}
// wscales: column major [N / BLOCK_N, K / group size, 1, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
__device__ __forceinline__
static void load_wscale(const packed_wscale_t *wscales, int group, int N, wscale_warp &out, bool pred) {
int laneId = threadIdx.x % WARP_SIZE;
// static_assert(WSCALES_NUM_PACKS * WSCALES_VALID_LANES == 32);
// static_assert(sizeof(packed_wscale_t) == 8);
// const packed_wscale_t *ptr = &wscales[(group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId];
// // const packed_wscale_t *ptr = (const packed_wscale_t *)((const char *)wscales) + ((group * WSCALES_NUM_PACKS + 0) * WSCALES_VALID_LANES + laneId) * sizeof(packed_wscale_t);
#pragma unroll
for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
if (pred && laneId < WSCALES_VALID_LANES) {
// out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
// out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
// out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
}
}
}
// get the {k}-th and {k+1}-th wscale from the block; k must be a multiple of 2 and uniform across all lanes
__device__ __forceinline__
static half2_t broadcast_wscale(wscale_warp block, int k, int laneId) {
const int packIdx = k / (WSCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
const int elementIdx = k % WSCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
// get the {k}-th and {k+1}-th ascale from the block; k must be a multiple of 2 and uniform across all lanes
__device__ __forceinline__
static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
const int packIdx = k / (ASCALES_PACK_SIZE * WARP_SIZE);
const int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;
const int elementIdx = k % ASCALES_PACK_SIZE / 2;
return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
}
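// Worked example (illustrative, WSCALES_PACK_SIZE = 4): for tile j = 1, apply_scales below
// calls broadcast_wscale(block, 4, laneId), so srcLane = 4 + laneId % 4 and elementIdx = 0;
// lane 9 therefore receives the half2 {wscale[18], wscale[19]} from lane 5, which are exactly
// the two columns 16 * 1 + 2 * (9 % 4) = 18, 19 that lane 9 owns in the m16n16 output fragment.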
template<typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
half2_t asx[WARP_M_TILES];
half2_t asy[WARP_M_TILES];
for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half2_t(as.x, as.x);
asy[i] = half2_t(as.y, as.y);
}
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t ws1 = broadcast_wscale(wscale, j * 4, laneId);
half2_t ws2 = broadcast_wscale(wscale, j * 4 + 2, laneId);
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
packed_psum_t psum = getpsum(i, j);
// constexpr int target = 0;
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
// if (threadIdx.x == 3 && j == 1 && i == 0) {
// printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
// }
}
}
}
template<typename F>
__device__ __forceinline__
static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
float2 asx[WARP_M_TILES];
float2 asy[WARP_M_TILES];
for (int i = 0; i < WARP_M_TILES; i++) {
half2_t as = broadcast_ascale(ascale, i * 2, laneId);
asx[i] = half22float2(half2_t(as.x, as.x));
asy[i] = half22float2(half2_t(as.y, as.y));
}
auto fma2 = [](float2 a, float2 b, float &cx, float &cy) ALWAYSINLINE {
cx += a.x * b.x;
cy += a.y * b.y;
};
for (int j = 0; j < WARP_N_TILES; j++) {
float2 ws1 = half22float2(broadcast_wscale(wscale, j * 4, laneId));
float2 ws2 = half22float2(broadcast_wscale(wscale, j * 4 + 2, laneId));
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
packed_psum_t psum = getpsum(i, j);
fma2(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1])), asx[i] * ws1, fsum.data[0], fsum.data[1]);
fma2(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3])), asy[i] * ws1, fsum.data[2], fsum.data[3]);
fma2(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5])), asx[i] * ws2, fsum.data[4], fsum.data[5]);
fma2(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7])), asy[i] * ws2, fsum.data[6], fsum.data[7]);
}
}
}
/**
* input: WARP_M of half (in shared memory, per-warp)
* output: [..., ASCALES_NUM_PACKS, ASCALES_VALID_LANES] in global memory, per-warp
*/
__device__ __forceinline__
static void pack_ascales(const half_t *input, packed_ascale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int j = 0; j < ASCALES_NUM_PACKS; j++) {
if (laneId < ASCALES_VALID_LANES) {
packed_ascale_t tmp;
#pragma unroll
for (int i = 0; i < ASCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2].x = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + i * 8];
tmp.data[i / 2].y = input[j * ASCALES_PACK_SIZE * WARP_SIZE + laneId / 8 * 8 * ASCALES_PACK_SIZE + laneId % 8 + (i + 1) * 8];
}
output[j * ASCALES_VALID_LANES + laneId] = tmp;
}
}
}
/**
* input: WARP_N of half (in shared memory, per-warp)
* output: [..., WSCALES_NUM_PACKS, WSCALES_VALID_LANES] in global memory, per-warp
*/
__device__ __forceinline__
static void pack_wscales(const half_t *input, packed_wscale_t *output) {
const int laneId = threadIdx.x % WARP_SIZE;
#pragma unroll
for (int j = 0; j < WSCALES_NUM_PACKS; j++) {
if (laneId < WSCALES_VALID_LANES) {
packed_wscale_t tmp;
#pragma unroll
for (int i = 0; i < WSCALES_PACK_SIZE; i += 2) {
tmp.data[i / 2] = *reinterpret_cast<const half2_t *>(&input[j * WSCALES_PACK_SIZE * WARP_SIZE + laneId / 4 * 4 * WSCALES_PACK_SIZE + laneId % 4 * 2 + i * 4]);
}
store(&output[j * WSCALES_VALID_LANES + laneId], tmp);
}
}
}
struct unpack_fpsum {
// +8 to prevent bank conflicts
using matrix_t = half_t[8][WARP_N + 8];
static constexpr int SHMEM_SIZE = sizeof(matrix_t);
static constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
// F (int rowId, pack_t &pack)
template<typename ...F>
__device__ __forceinline__
void operator()(fpsum_warp fpsum, half_t *output, int stride, void *shmem, F &&...plugins) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
// pack_t reduce_tmp;
// constexpr bool enableReduce = !std::is_void_v<FuncReduce>;
// if constexpr (enableReduce) {
// reduce_tmp.fill(reduce_initval);
// // reduce_tmp = load<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]));
// }
// auto doReduce = [&reduce_tmp](pack_t pack) {
// if constexpr (enableReduce) {
// for (int i = 0; i < PACK_SIZE; i++) {
// reduce_tmp[i] = FuncReduce()(reduce_tmp[i], pack[i]);
// }
// }
// };
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2];
}
__syncwarp();
#pragma unroll
for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
// if constexpr (enableReduce) {
// doReduce(pack);
// }
(plugins(i * INSN_M + row, pack), ...);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 * 2 + j * INSN_N;
*reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1];
*reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3];
}
__syncwarp();
#pragma unroll
for (int row = 0; row < 8; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
// if constexpr (enableReduce) {
// doReduce(pack);
// }
(plugins(i * INSN_M + 8 + row, pack), ...);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
}
// if (enableReduce) {
// store<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]), reduce_tmp);
// }
}
};
struct EpilogueDefault {
struct Arguments {};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, shmem[warpId]);
}
};
struct EpilogueNop {
struct Arguments {};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
}
};
template<typename ...Epilogues>
struct EpilogueCombination {
using Arguments = std::tuple<typename Epilogues::Arguments...>;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
// this function makes IntelliSense crash :(
#if __INTELLISENSE__
__trap(); // should not happen when actually compiling
#else
std::tuple<Epilogues...> epilogues;
auto run = [&]<size_t idx>() {
std::get<idx>(epilogues).operator()(binfo, fpsum, out, M, N, K, std::get<idx>(args));
};
auto foreach = [&]<size_t ...Is>(std::index_sequence<Is...>) {
(run.template operator()<Is>(), ...);
};
foreach(std::make_index_sequence<sizeof...(Epilogues)>());
#endif
}
};
};
class GEMM_W4A4 : public GEMMBase<GEMMConfig_W4A4> {
public:
template<bool ACT_UNSIGNED>
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt) {
packed_psum_t psum;
if constexpr (!ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
if constexpr (ACT_UNSIGNED) {
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
"r"(0), "r"(0), "r"(0), "r"(0)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
return psum;
}
// template<bool si>
template<bool use_unsigned>
__device__ __forceinline__
static void quantize_w4a4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, half_t *output_scale) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr float QVALUE_MAX_SIGNED = 7.0f;
constexpr float QVALUE_MAX_UNSIGNED = 15.0f;
constexpr float RECPI_QVALUE_MAX_SIGNED = 1 / QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX_UNSIGNED = 1 / QVALUE_MAX_UNSIGNED;
constexpr float QVALUE_MAX = use_unsigned ? QVALUE_MAX_UNSIGNED : QVALUE_MAX_SIGNED;
constexpr float RECPI_QVALUE_MAX = use_unsigned ? RECPI_QVALUE_MAX_UNSIGNED : RECPI_QVALUE_MAX_SIGNED;
// constexpr int QUANTIZE_BITMASK = 0xf;
// 0 for row 0-7; 1 for row 8-15
half2_t input[2][INSN_K / INSN_N * 2];
#pragma unroll
for (int i = 0; i < INSN_K / INSN_N; i++) {
input[0][i * 2 + 0] = fpsum[i].data[0];
input[0][i * 2 + 1] = fpsum[i].data[2];
input[1][i * 2 + 0] = fpsum[i].data[1];
input[1][i * 2 + 1] = fpsum[i].data[3];
}
half_t maxvalue[2];
maxvalue[0] = 0;
maxvalue[1] = 0;
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
half2_t abs0 = __habs2(input[0][i]);
half2_t abs1 = __habs2(input[1][i]);
maxvalue[0] = __hmax(maxvalue[0], __hmax(abs0.x, abs0.y));
maxvalue[1] = __hmax(maxvalue[1], __hmax(abs1.x, abs1.y));
}
#pragma unroll
for (int mask = 2; mask > 0; mask /= 2) {
maxvalue[0] = __hmax(maxvalue[0], __shfl_xor_sync(~0, maxvalue[0], mask));
maxvalue[1] = __hmax(maxvalue[1], __shfl_xor_sync(~0, maxvalue[1], mask));
}
maxvalue[0] = __shfl_sync(~0, maxvalue[0], laneId / 4 * 4);
maxvalue[1] = __shfl_sync(~0, maxvalue[1], laneId / 4 * 4);
float scale[2];
// scale[0] = float(maxvalue[0]) / QVALUE_MAX;
// scale[1] = float(maxvalue[1]) / QVALUE_MAX;
scale[0] = float(maxvalue[0]) * RECPI_QVALUE_MAX;
scale[1] = float(maxvalue[1]) * RECPI_QVALUE_MAX;
if (laneId % 4 == 0) {
output_scale[laneId / 4] = half_t(scale[0]);
output_scale[laneId / 4 + 8] = half_t(scale[1]);
}
float rscale[2];
// rscale[0] = QVALUE_MAX / float(maxvalue[0]);
// rscale[1] = QVALUE_MAX / float(maxvalue[1]);
rscale[0] = cuda_frcp(scale[0]);
rscale[1] = cuda_frcp(scale[1]);
uint32_t qpacks[2][INSN_K / INSN_M * 2];
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
// half2_t hval = __hmul2(input[j][i], half2_t(rscale[j], rscale[j]));
// float2 fval = half22float2(hval);
float2 fval = half22float2(input[j][i]) * make_float2(rscale[j], rscale[j]);
qpacks[j][i] = quantize_float2<4, use_unsigned>(fval) << (laneId % 4 * 8);
}
}
// 2 * 8 * 2 = 32 instructions => 256 cycles
#pragma unroll
for (int mask = 1; mask <= 2; mask *= 2) {
#pragma unroll
for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
#pragma unroll
for (int j = 0; j < 2; j++) {
qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
}
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
#pragma unroll
for (int i = 0; i < 4; i++) {
if (laneId % 4 == i) {
output.x = qpacks[0][0 + i];
output.y = qpacks[1][0 + i];
output.z = qpacks[0][4 + i];
output.w = qpacks[1][4 + i];
}
}
}
// loads act of [WARP_M, WARP_N] and stores to fpsum_warp
struct load_act_to_fpsum {
using matrix_t = half_t[WARP_M][WARP_N + 8];
static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);
__device__ __forceinline__
void operator()(const half_t *input, int stride, fpsum_warp &out, void *shmem /*, const half_t *smooth_factor */) {
const int laneId = threadIdx.x % WARP_SIZE;
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
using packed_input = std::array<half_t, PACK_SIZE>;
// packed_input pack_smooth;
// if (smooth_factor) {
// pack_smooth = load(reinterpret_cast<const packed_input *>(input + laneId * PACK_SIZE));
// }
for (int row = 0; row < WARP_M; row++) {
auto pack = load(reinterpret_cast<const packed_input *>(input + row * stride + laneId * PACK_SIZE));
// if (smooth_factor) {
// for (int i = 0; i < PACK_SIZE; i++) {
// pack[i] *= pack_smooth[i];
// }
// }
store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
}
__syncwarp();
for (int m = 0; m < WARP_M_TILES; m++) {
for (int n = 0; n < WARP_N_TILES; n++) {
const int row = m * INSN_M + laneId % 16;
const int col = n * INSN_N + laneId / 16 * 8;
uint4 tmp;
ldmatrix(&mat[row][col], tmp);
*reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
}
}
__syncwarp();
}
};
/**
* each warp quantizes an INSN_M * INSN_K (16 * 64) matrix
* input is per-warp (in global memory)
* output is per-thread (in regs)
* output_scale is per-warp (in shared memory)
* shmem must be at least INSN_M * INSN_K * sizeof(element) (16 * 64 * 0.5 = 512 Bytes)
* defaults to quantizing activations; to quantize weights, the input should be column-major and the output transposed ({x, y, z, w} = {x, z, y, w})
*/
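// Reference math (a scalar sketch of what this routine computes per row; names here are
// illustrative, not kernel variables):
//   float absmax = 0.f;
//   for (int k = 0; k < INSN_K; k++) absmax = fmaxf(absmax, fabsf((float)row[k]));
//   float scale  = absmax / 7.f;     // written to output_scale
//   float rscale = 7.f / absmax;
//   for (int k = 0; k < INSN_K; k++)
//       q[k] = (int)roundf((float)row[k] * rscale);   // in [-7, 7], packed 8 values per uint32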
__device__ __forceinline__
static void quantize_w4a4_warp(const half_t *input, int stride, packed_act_t &output, half_t *output_scale, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 4;
constexpr int QVALUE_MAX = 7; // 4 bit => [-8, 7]
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 8 for 4bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half_t, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS];
// load
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// find max
half_t maxvalue[NUM_PACKWARPS];
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __habs(packs[i][0]);
#pragma unroll
for (int j = 1; j < PACK_SIZE; j++) {
maxvalue[i] = __hmax(maxvalue[i], __habs(packs[i][j]));
}
}
// warp reduce (max)
#pragma unroll
for (int mask = NUM_PACKS_PER_ROW / 2; mask > 0; mask /= 2) {
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __hmax(maxvalue[i], __shfl_xor_sync(~0, maxvalue[i], mask));
}
}
// broadcast (max)
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
maxvalue[i] = __shfl_sync(~0, maxvalue[i], laneId / NUM_PACKS_PER_ROW * NUM_PACKS_PER_ROW);
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
half_t scale = maxvalue[i] / half_t(QVALUE_MAX);
half_t rscale = half_t(QVALUE_MAX) / maxvalue[i];
if (laneId % NUM_PACKS_PER_ROW == 0) {
output_scale[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW] = scale;
}
uint32_t qpack = 0;
// #pragma unroll
// for (int j = 0; j < PACK_SIZE; j++) {
// int intvalue = __half2int_rn(packs[i][j] / scale);
// intvalue = clamp(intvalue, -QVALUE_MAX, QVALUE_MAX);
// qpack |= (intvalue & QUANTIZE_BITMASK) << (QUANTIZE_BITWIDTH * j);
// }
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
half2_t hval = __hmul2(half2_t(rscale, rscale), half2_t(packs[i][j], packs[i][j + 1]));
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(half22float2(hval)) << (j * QUANTIZE_BITWIDTH);
}
mat[i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW][laneId % NUM_PACKS_PER_ROW] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
}
// each thread block (1 warp) quantizes a WARP_M * WARP_K tile (32 * 64)
struct quantize_w4a4_act_kernel {
__device__
void operator()(const half_t *input, packed_act_t *output, packed_ascale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int bk = blockIdx.y;
const int warpId = blockIdx.x % (BLOCK_M / WARP_M);
const int row = blockIdx.x * WARP_M;
const int col = blockIdx.y * WARP_K;
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
for (int tileId = 0; tileId < WARP_M_TILES; tileId++) {
packed_act_t tmpout;
quantize_w4a4_warp(
input + (row + tileId * INSN_M) * K + col,
K,
tmpout,
oscale_shmem + tileId * INSN_M,
tmp_shmem
);
store(&output[(((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * WARP_M_TILES + tileId) * WARP_SIZE + laneId], tmpout);
}
// if (threadIdx.x == 0) {
// printf("Block (%d, %d) => offset = %d\n", blockIdx.x, blockIdx.y, (bm * K / WARP_K + bk) * NUM_WARPS + warpId);
// }
pack_ascales(oscale_shmem, &oscales[((bm * K / WARP_K + bk) * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
// each thread block (1 warp) quantizes a WARP_N * WARP_K tile (128 * 64)
struct quantize_w4a4_wgt_kernel {
__device__
void operator()(const half_t *input, packed_wgt_t *output, packed_wscale_t *oscales, int K) {
const int laneId = threadIdx.x % WARP_SIZE;
const int bn = blockIdx.x / (BLOCK_N / WARP_N);
const int bk = blockIdx.y;
const int col = blockIdx.x * WARP_N;
const int row = blockIdx.y * WARP_K;
__shared__ alignas(128) half_t oscale_shmem[WARP_N];
__shared__ alignas(128) uint8_t tmp_shmem[INSN_M * INSN_K / 2];
for (int tileId = 0; tileId < WARP_N_TILES; tileId++) {
packed_wgt_t tmpout;
quantize_w4a4_warp(
input + (col + tileId * INSN_N) * K + row,
K,
tmpout,
oscale_shmem + tileId * INSN_N,
tmp_shmem
);
std::swap(tmpout.y, tmpout.z);
store(&output[((bn * K / WARP_K + bk) * WARP_N_TILES + tileId) * WARP_SIZE + laneId], tmpout);
}
pack_wscales(oscale_shmem, &oscales[(bn * K / WARP_K + bk) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES]);
}
};
template<bool ACT_UNSIGNED, typename T>
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
apply_scales([&](int i, int j) {
return mma<ACT_UNSIGNED>(A[i], W[j]);
}, ascale, wscale, fpsum);
}
__device__ __forceinline__
static void checkNan(fpsum_warp fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int i = 0; i < fpsum.size(); i++) {
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum[i].data[j].x) || !isfinite((float)fpsum[i].data[j].y);
if (abnormal) {
printf("abnormal value detected at block.x=%d block.y=%d warpId=%d laneId=%d fpsum_warp (%s) i=%d j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
i, j,
(float)fpsum[i].data[j].x,
(float)fpsum[i].data[j].y
);
__trap();
}
}
}
#endif
}
__device__ __forceinline__
static void checkNan(packed_f32psum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int j = 0; j < 8; j++) {
bool abnormal = !isfinite(fpsum.data[j]);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_f32psum_t (%s) j=%d data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
j,
fpsum.data[j]
);
__trap();
}
}
#endif
}
__device__ __forceinline__
static void checkNan(packed_fpsum_t fpsum, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
for (int j = 0; j < 4; j++) {
bool abnormal = !isfinite((float)fpsum.data[j].x) || !isfinite((float)fpsum.data[j].y);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) j=%d data.x=%f data.y=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
j,
(float)fpsum.data[j].x,
(float)fpsum.data[j].y
);
__trap();
}
}
#endif
}
__device__ __forceinline__
static void checkNan(float data, const char *info = "") {
#if ENABLE_NAN_CHECK
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
bool abnormal = !isfinite(data);
if (abnormal) {
printf("abnormal value detected at bm=%d bn=%d warpId=%d laneId=%d packed_fpsum_t (%s) data=%f\n",
blockIdx.x, blockIdx.y,
warpId, laneId,
info,
data
);
__trap();
}
#endif
}
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
template<typename Epilogue, bool ACT_UNSIGNED>
__device__ __forceinline__
static void gemm_w4a4_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias_ptr,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
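// Two-stage software pipeline: the loads for iteration k+1 are issued before the MMAs of
// iteration k so that global-memory latency overlaps with tensor-core work; A, W, ascale
// and wscale are double-buffered (NUM_STAGES buffers each).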
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale[NUM_STAGES]; // 1
wscale_warp wscale[NUM_STAGES]; // 2
fpsum_warp fpsum; // 64
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
load_ascale(ascales, k, M, ascale[k], true);
load_wscale(wscales, k, N, wscale[k], true);
}
for (auto &pack : fpsum) {
#if 1
for (int i = 0; i < 4; i++) {
pack.data[i].x = 0;
pack.data[i].y = 0;
}
#else
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
#endif
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
load_ascale(ascales, nextk, M, ascale[idx], pred);
load_wscale(wscales, nextk, N, wscale[idx], pred);
// load_wscale<false>(wscales, wscale[idx], pred);
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
compute<ACT_UNSIGNED>(A[k2], W[k2], ascale[k2], wscale[k2], fpsum);
if (alwaysfalse) {
dummy = clock();
}
// asm volatile ("membar.cta;");
}
}
unused_var(dummy, alwaysfalse);
#if 0
auto f16psum = packed_fp32_to_fp16(fpsum);
#else
auto f16psum = fpsum;
#endif
CHECK_NAN(f16psum, "f16psum");
Epilogue()(binfo, f16psum, out, M, N, K, epilogueArgs);
}
template<bool FUSE_GELU, bool USE_UNSIGNED>
struct EpilogueQuantize {
struct Arguments {
packed_act_t *qout;
packed_ascale_t *oscales;
half_t shift_value;
const packed_wscale_t *smooth_factor;
};
static constexpr int NUM_PACKS = INSN_K / INSN_N;
static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;
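// Per-element math of this epilogue (sketch; shift_value and smooth_factor are the
// Arguments above):
//   y = (FUSE_GELU ? gelu(x) : x) + shift_value;
//   y = y / smooth_factor[column];                     // per-output-channel smoothing
//   q = 4-bit quantization of y with one scale per row per 64-column group
//       (quantize_w4a4_from_fpsum_warp; scales are packed into oscales)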
__device__ __forceinline__
void apply_quantize(fpsum_warp fpsum, half_t *out, int M, int N, int K, packed_act_t *qout, packed_ascale_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ half_t oscale_shmem[NUM_WARPS][WARP_M];
wscale_warp smooth;
load_wscale(smooth_factor, 0, N, smooth, true);
#pragma unroll
for (int group = 0; group < NUM_GROUPS; group++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t tmp[NUM_PACKS];
#pragma unroll
for (int j = 0; j < NUM_PACKS; j++) {
half2_t ws1 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4, laneId);
half2_t ws2 = broadcast_wscale(smooth, (group * NUM_PACKS + j) * 4 + 2, laneId);
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t src = fpsum[i * WARP_N_TILES + group * NUM_PACKS + j].data[k];
half2_t &dst = tmp[j].data[k];
// dst.x = gelu(src.x);
// dst.y = gelu(src.y);
if constexpr (FUSE_GELU) {
dst = gelu_half2(src);
} else {
dst = src;
}
dst += half2_t(shift_value, shift_value);
// dst = src;
}
auto h2div = [](half2_t a, half2_t b) ALWAYSINLINE {
float2 af = half22float2(a);
float2 bf = half22float2(b);
float2 of;
of.x = __fdividef(af.x, bf.x);
of.y = __fdividef(af.y, bf.y);
return float22half2<half2_t>(of);
};
tmp[j].data[0] = h2div(tmp[j].data[0], ws1);
tmp[j].data[1] = h2div(tmp[j].data[1], ws1);
tmp[j].data[2] = h2div(tmp[j].data[2], ws2);
tmp[j].data[3] = h2div(tmp[j].data[3], ws2);
}
packed_act_t qresult;
quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
store(&qout[((group * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], qresult);
}
__syncwarp();
pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
__syncwarp();
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
apply_quantize(
fpsum, out, M, N, K,
args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
args.shift_value,
args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
};
// using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
template<int rank = 32>
struct Lora {
static_assert(rank % 16 == 0);
static constexpr int LORA_RANK = rank;
static constexpr int LORA_M_TILES = WARP_M / 16;
static constexpr int LORA_R_TILES = LORA_RANK / 16;
static constexpr int LORA_N_TILES = WARP_N / 16;
static_assert(LORA_M_TILES == WARP_M_TILES);
static_assert(LORA_N_TILES == WARP_N_TILES);
// lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
// lora up: [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
// we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;
using scale_t = std::array<float, LORA_R_TILES>;
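// Reference math (scalar sketch; X, W_down, W_up are illustrative names): for a rank-r
// adapter the two epilogues below compute, per [WARP_M, WARP_N] tile,
//   lora_act  = X @ W_down                         // EpilogueLoraDown, accumulated in fp32
//   Y        += (lora_act * scales) @ W_up^T       // EpilogueLoraUp, one scale per 16 ranks
// where both W_down and W_up are stored as [N, r] (hence the transpose on the up path).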
// lora_wgt: [N / 16, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
__device__ __forceinline__
static lora_wgt_warp load_lora_wgt(const packed_fpsum_t *ptr) {
const int laneId = threadIdx.x % WARP_SIZE;
const packed_fpsum_t *ptr_lane = ptr + laneId;
lora_wgt_warp result;
#if 0
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
result[n * LORA_R_TILES + r] = load(ptr_lane + (n * LORA_R_TILES + r) * WARP_SIZE);
}
}
#else
unrolled_loop<LORA_N_TILES>([&]<int n>() {
unrolled_loop<LORA_R_TILES>([&]<int r>() {
constexpr int offset = (n * LORA_R_TILES + r) * WARP_SIZE;
result[n * LORA_R_TILES + r] = load(ptr_lane + offset);
});
});
#endif
return result;
}
// lora_act: [M / BLOCK_M, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
__device__ __forceinline__
static lora_act16_warp load_lora_act(const float *ptr, scale_t scales) {
const int laneId = threadIdx.x % WARP_SIZE;
const float *ptrlane = ptr + laneId;
lora_act16_warp result;
#if 0
#pragma unroll
for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
packed_f32psum_t tmp;
#pragma unroll
for (int j = 0; j < 8; j++) {
const int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset];
// tmp.data[j] = ptr[i * 8 * WARP_SIZE + j * WARP_SIZE + laneId];
}
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
}
#else
unrolled_loop<LORA_M_TILES>([&]<int m>() {
unrolled_loop<LORA_R_TILES>([&]<int r>{
constexpr int i = m * LORA_R_TILES + r;
packed_f32psum_t tmp;
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
tmp.data[j] = ptrlane[offset] * scales[r];
});
CHECK_NAN(tmp, "load_lora_act.tmp");
result[i] = packed_fp32_to_fp16(tmp);
});
});
#endif
return result;
}
// no vector reduction in sm_89 :(
__device__ __forceinline__
static void reduce_lora_act(float *ptr, lora_act_warp val) {
const int laneId = threadIdx.x % WARP_SIZE;
float *ptrlane = ptr + laneId;
// #pragma unroll
// for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
// #pragma unroll
// for (int j = 0; j < 8; j++) {
// int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[i].data[j]);
// }
// }
unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
unrolled_loop<8>([&]<int j>() {
constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
reduce_add(&ptrlane[offset], val[i].data[j]);
});
});
}
// __device__ __forceinline__
// static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
// const int laneId = threadIdx.x % WARP_SIZE;
// float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
// unrolled_loop<LORA_R_TILES>([&]<int r>() {
// unrolled_loop<8>([&]<int j>() {
// constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
// reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
// });
// });
// }
struct EpilogueLoraUp {
struct Arguments {
const float *lora_act;
const packed_fpsum_t *lora_wgt_up;
scale_t scales;
};
__device__ __forceinline__
static void apply_lora_up(fpsum_warp &fpsum, half_t *out, int M, int N, int K, const float *act, const packed_fpsum_t *wgt, const scale_t scales, const BlockInfo binfo) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
CHECK_NAN(fpsum, "fpsum");
if constexpr (rank == 0) {
return;
}
apply_lora_up(
fpsum, out, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_up + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE,
args.scales,
binfo // for debug
);
CHECK_NAN(fpsum, "fpsum");
}
};
struct EpilogueLoraDown {
struct Arguments {
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
};
__device__ __forceinline__
static void apply_lora_down(fpsum_warp &fpsum, half_t *out, int M, int N, int K, float *act, const packed_fpsum_t *wgt) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
// clock_t dummy = 0;
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
// unused_var(dummy, alwaysfalse);
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if constexpr (rank == 0) {
return;
}
apply_lora_down(
fpsum, out, M, N, K,
args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
args.lora_wgt_down + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE
);
}
};
struct quantize_w4a4_fuse_lora_kernel {
static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum::SHMEM_SIZE, 128) * 128;
static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;
struct Arguments {
const half_t *input;
const packed_wscale_t *smooth_factor;
packed_act_t *output;
packed_ascale_t *oscales;
const packed_fpsum_t *lora_wgt_down;
float *lora_act;
int M, N;
};
__device__ __forceinline__
void operator()(Arguments args)
{
const BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
const int bm = binfo.bm;
const int bn = binfo.bn;
const int warpId = threadIdx.x / WARP_SIZE;
const int m_offset = bm * BLOCK_M + warpId * WARP_M;
const int n_offset = bn * BLOCK_N;
extern __shared__ uint8_t shmem[];
fpsum_warp fpsum;
// FIXME: smooth factor should change to EpilogueQuantize
load_act_to_fpsum()(
args.input + m_offset * args.N + n_offset,
args.N,
fpsum,
shmem + warpId * SHMEM_PER_WARP
// args.smooth_factor ? args.smooth_factor + n_offset : nullptr
);
CHECK_NAN(fpsum, "fpsum");
// for (int i = 0; i < 16; i++) {
// printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
// bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
// (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
// }
EpilogueLoraDown()(binfo, fpsum, nullptr, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
.lora_wgt_down = args.lora_wgt_down,
.lora_act = args.lora_act,
});
EpilogueQuantize<false, false>()(binfo, fpsum, nullptr, args.M, args.N, 0, EpilogueQuantize<false, false>::Arguments{
.qout = args.output,
.oscales = args.oscales,
.shift_value = 0,
.smooth_factor = args.smooth_factor
});
}
};
};
struct EpilogueBias {
struct Arguments {
const packed_wscale_t *bias; // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
};
__device__ __forceinline__
void apply_bias(fpsum_warp &fpsum, half_t *out, int M, int N, int K, const packed_wscale_t *bias) {
const int laneId = threadIdx.x % WARP_SIZE;
wscale_warp b;
load_wscale(bias, 0, N, b, true);
for (int j = 0; j < WARP_N_TILES; j++) {
half2_t b1 = broadcast_wscale(b, j * 4, laneId);
half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
for (int i = 0; i < WARP_M_TILES; i++) {
auto &fsum = fpsum[i * WARP_N_TILES + j];
fsum.data[0] = __hadd2(fsum.data[0], b1);
fsum.data[1] = __hadd2(fsum.data[1], b1);
fsum.data[2] = __hadd2(fsum.data[2], b2);
fsum.data[3] = __hadd2(fsum.data[3], b2);
}
}
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bn = binfo.bn;
apply_bias(
fpsum, out, M, N, K,
args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
);
}
};
struct EpilogueGelu {
struct Arguments {};
// static constexpr float SHIFT_VALUE = 0.171875f;
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp &fpsum, half_t *out, int M, int N, int K, Arguments args) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int k = 0; k < 4; k++) {
half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
data = gelu_half2(data);
// data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
}
}
}
}
};
// template<int PoolSize = 128>
struct EpilogueQKVProj {
struct Arguments {
half_t *pool_out; // [M / PoolSize, N]
const float *rotary_emb; // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
const half_t *rmsnorm_weight_q; // [HEAD_DIM]
const half_t *rmsnorm_weight_k; // [HEAD_DIM]
float epsilon;
};
static constexpr int HEAD_DIM = 128;
static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
static constexpr int PoolSize = 128;
static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;
static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2; // 1 for theta, 2 for {sin, cos} pair
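// Reference math for the Q/K blocks (scalar sketch over one HEAD_DIM slice):
//   coef = rsqrt(mean(x[0:HEAD_DIM]^2) + epsilon);
//   x[i] = x[i] * coef * rmsnorm_weight[i];                         // RMSNorm
//   (x[2i], x[2i+1]) = (x[2i]*cos - x[2i+1]*sin,
//                       x[2i]*sin + x[2i+1]*cos);                   // rotary embedding
// with (sin, cos) read directly when ROTARY_EMB_NUM_ELEMENTS == 2, or computed from theta
// when it is 1. V blocks (the last third of bn) fall through to EpilogueDefault.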
__device__ __forceinline__
static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K, half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight, float epsilon) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
using pack_t = unpack_fpsum::pack_t;
using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;
constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;
pack_t reduce_tmp;
reduce_tmp.fill(half_t(0));  // the mean-pool accumulator must start from zero
__shared__ alignas(128) pack_t pool[NUM_WARPS];
// load rmsnorm scales
pack_t rms;
if (laneId < LANES_PER_HEAD) {
rms = load(reinterpret_cast<const pack_t *>(&rmsnorm_weight[laneId * PACK_SIZE]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < PACK_SIZE; i++) {
rms[i] = __shfl_sync(~0, rms[i], laneId % LANES_PER_HEAD);
}
}
const float *rotary_emb_base_addr = &rotary_emb[(warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS + laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];
CHECK_NAN(fpsum, "fpsum");
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, shmem[warpId], [&](int rowId, pack_t &pack) ALWAYSINLINE {
// load rope
pack_rope_t rope;
if (laneId < LANES_PER_HEAD) {
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope = load(reinterpret_cast<const pack_rope_t *>(&rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
}
if constexpr (LANES_PER_HEAD < WARP_SIZE) {
for (int i = 0; i < rope.size(); i++) {
rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
}
}
// rmsnorm
float sqrsum = 0.0f;
for (int i = 0; i < PACK_SIZE; i++) {
sqrsum += float(pack[i]) * float(pack[i]);
CHECK_NAN(sqrsum, "sqrsum");
}
#pragma unroll
for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
}
sqrsum /= HEAD_DIM;
float coef = cuda_frsqrt(sqrsum + epsilon);
CHECK_NAN(coef, "coef");
for (int i = 0; i < PACK_SIZE; i++) {
pack[i] *= coef * float(rms[i]);
CHECK_NAN(rms[i], "rms.wgt");
CHECK_NAN(pack[i], "rms.out");
}
#if 1
// rope
for (int i = 0; i < PACK_SIZE; i += 2) {
float2 pack2 = half22float2(half2_t(pack[i], pack[i+1]));
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
CHECK_NAN(rope[i / 2], "rope.freq");
}
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
CHECK_NAN(rope[i], "rope.freq");
CHECK_NAN(rope[i + 1], "rope.freq");
}
// half2_t tmp = __hmul2(freq[i], pack2);
// tmp = __hfma2(freq[i+1], pack2, tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
// );
// __trap();
// half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
// tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
float sin, cos;
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
sin = cuda_sin(rope[i / 2]);
cos = cuda_cos(rope[i / 2]);
}
if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
sin = rope[i];
cos = rope[i+1];
}
// pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
// pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
pack[i] = half_t(pack2.x * cos - pack2.y * sin);
pack[i+1] = half_t(pack2.x * sin + pack2.y * cos);
CHECK_NAN(pack[i], "rope.out");
CHECK_NAN(pack[i+1], "rope.out");
}
#endif
// mean pool
for (int i = 0; i < PACK_SIZE; i++) {
reduce_tmp[i] += pack[i];
}
});
if (!pool_out) {
return;
}
store<true>(&pool[warpId], reduce_tmp);
__syncthreads();
if (warpId < NUM_POOLS_PER_BLOCK) {
const int row = warpId * NUM_WARPS_PER_POOL;
reduce_tmp = load<true>(&pool[row]);
for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
pack_t pack = load<true>(&pool[row + i]);
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] += pack[j];
}
}
for (int j = 0; j < PACK_SIZE; j++) {
reduce_tmp[j] /= PoolSize;
}
store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
}
__syncthreads();
}
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
assert(binfo.numBlocksN % 3 == 0);
const bool is_q = bn < binfo.numBlocksN / 3;
const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;
if (is_q || is_k) {
apply(
fpsum, out, M, N, K,
args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
args.epsilon
);
} else {
EpilogueDefault()(binfo, fpsum, out, M, N, K, {});
}
}
};
template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
// const packed_wscale_t *bias,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
}
const int bm = binfo.bm;
const int bn = binfo.bn;
// bool fusequant = !out;
gemm_w4a4_block<Epilogue, ACT_UNSIGNED>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (K / WARP_K) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
wscales + bn * (K / WARP_K) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
// bias ? bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES : nullptr,
out + (bm * BLOCK_M * N) + bn * BLOCK_N,
// out + (bm * N / BLOCK_N + bn) * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE,
M, N, K,
epilogueArgs,
alwaysfalse
);
}
};
};
class GEMM_W8A8 : public GEMMBase<GEMMConfig_W8A8> {
public:
using psum_warp = std::array<packed_psum_t, WARP_M_TILES * WARP_N_TILES>;
__device__ __forceinline__
static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
// packed_psum_t psum;
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.x), "r"(wgt.y),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
:
"=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
:
"r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
"r"(wgt.z), "r"(wgt.w),
// "r"(0), "r"(0), "r"(0), "r"(0)
"r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
return psum;
}
__device__ __forceinline__
static void compute(act_warp A, wgt_warp W, psum_warp &psum) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
#pragma unroll
for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
for (int i = 0; i < WARP_M_TILES; i++) {
psum[i * WARP_N_TILES + j] = mma(A[i], W[j], psum[i * WARP_N_TILES + j]);
}
}
}
/**
* each warp quantizes an INSN_M * INSN_K (16 * 32) matrix
* input is per-warp (in global memory / shared memory)
* rscales is per-warp (in shared memory)
* output is per-thread (in regs)
* shmem must be at least INSN_M * (INSN_K * sizeof(element) + 16) (16 * 32 = 512 Bytes)
* defaults to quantizing activations; to quantize weights, the input should be column-major and the output transposed ({x, y, z, w} = {x, z, y, w})
*/
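// Reference math (scalar sketch): the same per-row symmetric scheme as the 4-bit path,
// but with the 8-bit range and a reciprocal scale precomputed by the caller:
//   scale = absmax / 127;  rscale = 127 / absmax;
//   q[k]  = round(x[k] * rscale), clamped to the signed 8-bit range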
template<bool input_shmem = false>
__device__ __forceinline__
static void quantize_w8a8_warp(const half *input, const half *rscales, int stride, packed_act_t &output, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int QUANTIZE_BITWIDTH = 8;
// constexpr int QUANTIZE_BITMASK = 0xff;
// constexpr int QVALUE_MAX = 128; // 8 bit => [-128, 127]
// 1 lane = 1 pack
// 1 warp = 32 lanes = 32 packs = 1 packwarp
// a pack is {a0, ..., a7} in figure https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=ex2#mma-16864-a
// PACK_SIZE * 4 = INSN_K / 2
constexpr int PACK_SIZE = INSN_K / 8; // = 4 for 8bit
constexpr int NUM_PACKS_PER_ROW = INSN_K / PACK_SIZE;
constexpr int NUM_ROWS_PER_PACKWARP = PACK_SIZE * WARP_SIZE / INSN_K;
constexpr int NUM_PACKWARPS = INSN_M / NUM_ROWS_PER_PACKWARP;
using packed_input = std::array<half, PACK_SIZE>;
packed_input packs[NUM_PACKWARPS];
// load
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
int rowId = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
int colId = laneId % NUM_PACKS_PER_ROW * PACK_SIZE;
packs[i] = load<input_shmem>(reinterpret_cast<const packed_input *>(input + rowId * stride + colId));
}
// quantize
using matrix_t = uint32_t[INSN_M][NUM_PACKS_PER_ROW];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
#pragma unroll
for (int i = 0; i < NUM_PACKWARPS; i++) {
const int row = i * NUM_ROWS_PER_PACKWARP + laneId / NUM_PACKS_PER_ROW;
const int col = laneId % NUM_PACKS_PER_ROW;
half rscale = rscales[row];
uint32_t qpack = 0;
#pragma unroll
for (int j = 0; j < PACK_SIZE; j += 2) {
half2 hval = __hmul2(make_half2(rscale, rscale), make_half2(packs[i][j], packs[i][j + 1]));
qpack |= quantize_float2<QUANTIZE_BITWIDTH, false>(__half22float2(hval)) << (j * QUANTIZE_BITWIDTH);
}
mat[row][col] = qpack;
}
__syncwarp();
// convert to imma format
int row = laneId % 16;
int col = laneId / 16 * 4;
ldmatrix(&mat[row][col], output);
__syncwarp();
}
/**
* each warp finds absmax from a row
*/
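// When fuse_glu is set, the row is read as interleaved (a, b) pairs; the gated value
// a * gelu(b) is written to output_shmem (so the row shrinks from K to K / 2 elements)
// and the absmax is taken over the gated values, so the later quantization sees the
// already-gated activation.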
template<bool fuse_glu = false>
__device__ __forceinline__
static half findmax_warp(const half_t *input, half_t *output_shmem, int K, bool alwaysfalse) {
const int laneId = threadIdx.x % WARP_SIZE;
using packed_input = std::array<half2_t, 4>;
using packed_gated_input = std::array<half_t, 4>;
constexpr int PACK_SIZE = sizeof(packed_input) / sizeof(half_t);
constexpr int NUM_STAGES = 2;
half2_t maxvalue2 = { 0, 0 };
packed_input pack[NUM_STAGES];
#pragma unroll
for (int k = 0; k < NUM_STAGES - 1; k++) {
const int idx = k * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
if (idx < K) {
pack[k] = load(reinterpret_cast<const packed_input *>(&input[idx]));
} else {
pack[k].fill(make_half2(0, 0));
}
}
// int dummy = 0;
// FIXME: pipeline does not work
// TODO: store quantized data to shmem (instead of half)
for (int k1 = 0; k1 < ceilDiv(K, PACK_SIZE * WARP_SIZE); k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
const int nextidx = (k1 + k2 + NUM_STAGES - 1) * PACK_SIZE * WARP_SIZE + laneId * PACK_SIZE;
const int nextk2 = (k2 + NUM_STAGES - 1) % NUM_STAGES;
if (nextidx < K) {
pack[nextk2] = load(reinterpret_cast<const packed_input *>(&input[nextidx]));
} else {
pack[nextk2].fill(make_half2(0, 0));
}
packed_input &p = pack[k2];
if constexpr (fuse_glu) {
packed_gated_input gated;
#pragma unroll
for (int j = 0; j < p.size(); j++) {
gated[j] = p[j].x * gelu_half(p[j].y);
p[j].x = gated[j];
p[j].y = 0;
}
int idx = (k1 + k2) * PACK_SIZE / 2 * WARP_SIZE + laneId * PACK_SIZE / 2;
if (idx < K / 2) { // output_shmem row holds K / 2 gated values
store<true>(reinterpret_cast<packed_gated_input *>(&output_shmem[idx]), gated);
}
}
#pragma unroll
for (int j = 0; j < p.size(); j++) {
maxvalue2 = __hmax2(maxvalue2, __habs2(p[j]));
}
}
}
// unused_var(dummy, alwaysfalse);
#pragma unroll
for (int mask = 32 / 2; mask > 0; mask /= 2) {
maxvalue2 = __hmax2(maxvalue2, __shfl_xor_sync(~0, maxvalue2, mask));
}
return __hmax(maxvalue2.x, maxvalue2.y);
}
// each thread block quantizes a WARP_M * K tile (32 * K)
template<bool fuse_glu>
struct quantize_w8a8_act_kernel {
static bool check(int M, int K) {
const int K2 = fuse_glu ? K / 2 : K;
return M % WARP_M == 0 && K2 % WARP_K == 0;
}
static dim3 gridSize(int M, int K) {
return dim3(M / WARP_M);
}
static dim3 blockSize(int M, int K) {
return dim3(NUM_WARPS * 32);
}
static size_t smemSize(int M, int K) {
if constexpr (!fuse_glu) {
return 0;
}
const int K2 = fuse_glu ? K / 2 : K;
return INSN_M * K2 * sizeof(half_t);
}
__device__
void operator()(const half *input, packed_act_t *output, packed_ascale_t *oscales, int K, bool alwaysfalse) {
// for quantize kernel
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
const int numWarps = blockDim.x / WARP_SIZE;
// for GEMM kernel
const int bm = blockIdx.x / (BLOCK_M / WARP_M);
const int gemmWarpId = blockIdx.x % (BLOCK_M / WARP_M);
__shared__ alignas(128) half_t oscale_shmem[WARP_M];
__shared__ alignas(128) half_t rscale_shmem[WARP_M];
__shared__ alignas(128) uint8_t tmp_shmem[NUM_WARPS][512];
const int K2 = fuse_glu ? K / 2 : K;
// INSN_M * K2
extern __shared__ uint8_t smem[];
half_t *shmem = reinterpret_cast<half_t *>(smem);
for (int tileM = 0; tileM < WARP_M_TILES; tileM++) {
for (int i = warpId; i < INSN_M; i += numWarps) {
const int rowLocal = tileM * INSN_M + i;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
half maxv = findmax_warp<fuse_glu>(input + rowGlobal * K, shmem + i * K2, K, alwaysfalse);
oscale_shmem[rowLocal] = maxv / half(127);
rscale_shmem[rowLocal] = half(127) / maxv;
}
__syncthreads();
for (int bk = warpId; bk < K2 / WARP_K; bk += numWarps) {
const int rowLocal = tileM * INSN_M;
const int rowGlobal = blockIdx.x * WARP_M + rowLocal;
const int col = bk * WARP_K;
packed_act_t tmpout;
if constexpr (fuse_glu) {
quantize_w8a8_warp<true>(
shmem + col,
rscale_shmem + rowLocal,
K2,
tmpout,
&tmp_shmem[warpId]
);
} else {
quantize_w8a8_warp<false>(
input + rowGlobal * K + col,
rscale_shmem + rowLocal,
K,
tmpout,
&tmp_shmem[warpId]
);
}
store(&output[(((bm * K2 / WARP_K + bk) * NUM_WARPS + gemmWarpId) * WARP_M_TILES + tileM) * WARP_SIZE + laneId], tmpout);
}
__syncthreads();
}
// [M / BLOCK_M, 1, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
pack_ascales(oscale_shmem, &oscales[(bm * NUM_WARPS + gemmWarpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
}
};
__device__ __forceinline__
static gated_fpsum_warp apply_glu(fpsum_warp fpsum) {
gated_fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) {
half_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst = src.x * gelu_half(src.y);
}
}
}
return result;
}
template<typename F>
__device__ __forceinline__
static fpsum_warp apply_act(fpsum_warp fpsum, F func) {
fpsum_warp result;
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
for (int k = 0; k < 4; k++) {
half2_t &dst = result[i * WARP_N_TILES + j].data[k];
half2_t src = fpsum[i * WARP_N_TILES + j].data[k];
dst.x = func(src.x);
dst.y = func(src.y);
}
}
}
return result;
}
static constexpr int unpack_gated_fpsum_shmem_size = INSN_M * (WARP_N / 2 + 8) * sizeof(half_t);
__device__ __forceinline__
static void unpack_gated_fpsum(gated_fpsum_warp fpsum, half_t *output, int stride, void *shmem) {
const int laneId = threadIdx.x % WARP_SIZE;
constexpr int PACK_SIZE = WARP_N / 2 / WARP_SIZE;
using pack_t = std::array<half_t, PACK_SIZE>;
// +8 to prevent bank conflicts
using matrix_t = half_t[INSN_M][WARP_N / 2 + 8];
matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);
for (int i = 0; i < WARP_M_TILES; i++) {
for (int j = 0; j < WARP_N_TILES; j++) {
packed_gated_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
int row = laneId / 4;
int col = laneId % 4 + j * INSN_N / 2;
*reinterpret_cast<half_t *>(&mat[row][col + 0]) = fsum.data[0];
*reinterpret_cast<half_t *>(&mat[row][col + 4]) = fsum.data[2];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 0]) = fsum.data[1];
*reinterpret_cast<half_t *>(&mat[row + 8][col + 4]) = fsum.data[3];
}
__syncwarp();
for (int row = 0; row < INSN_M; row++) {
pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);
store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
}
__syncwarp();
}
}
// out: [M, N] <=> [..., NUM_WARPS, WARP_M, N] of half
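// Unlike the W4A4 path, W8A8 keeps a single int32 accumulator (psum_warp) across the whole
// K loop and applies the row/column scales once at the end via apply_scales, since there
// is only one quantization group per row in W8A8.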
template<typename Epilogue>
__device__ __forceinline__
static void gemm_w8a8_block(
const BlockInfo binfo,
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueParams,
bool alwaysfalse)
{
constexpr int NUM_STAGES = 2;
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
act_warp A[NUM_STAGES]; // 8
wgt_warp W[NUM_STAGES]; // 32
ascale_warp ascale; // 1
wscale_warp wscale; // 2
psum_warp psum; // 128
for (auto &pack : psum) {
for (int i = 0; i < 8; i++) {
pack.data[i] = 0;
}
}
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[2], true);
load_ascale(ascales, 0, M, ascale, true);
load_wscale(wscales, 0, N, wscale, true);
for (int k = 0; k < NUM_STAGES - 1; k++) {
load_act(act, k, K, A[k], true);
load_wgt(wgt, k, K, W[k], true);
}
int dummy = 0;
for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
#pragma unroll
for (int k2 = 0; k2 < NUM_STAGES; k2++) {
int nextk = k1 + k2 + NUM_STAGES - 1;
int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
bool pred = nextk < K / WARP_K;
load_act(act, nextk, K, A[idx], pred);
load_wgt(wgt, nextk, K, W[idx], pred);
// load_wscale<false>(wscales, wscale[idx], pred);
// __syncthreads();
// if (alwaysfalse) {
// dummy = clock();
// }
// if (alwaysfalse) {
// dummy = clock();
// }
compute(A[k2], W[k2], psum);
// if (alwaysfalse) {
// dummy = clock();
// }
// asm volatile ("membar.cta;");
}
}
unused_var(dummy, alwaysfalse);
fpsum_warp fpsum;
apply_scales([&](int i, int j) {
return psum[i * WARP_N_TILES + j];
}, ascale, wscale, fpsum);
Epilogue()(binfo, fpsum, out, M, N, K, epilogueParams);
}
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template<typename Epilogue>
struct gemm_w8a8_kernel {
__device__
void operator()(
const packed_act_t *act,
const packed_wgt_t *wgt,
const packed_ascale_t *ascales,
const packed_wscale_t *wscales,
half_t *out,
int M, int N, int K,
Epilogue::Arguments epilogueArgs,
bool swapBlockXY,
bool alwaysfalse)
{
BlockInfo binfo = {
.bm = (int)blockIdx.x,
.bn = (int)blockIdx.y,
.numBlocksM = (int)gridDim.x,
.numBlocksN = (int)gridDim.y,
};
if (swapBlockXY) {
std::swap(binfo.bm, binfo.bn);
std::swap(binfo.numBlocksM, binfo.numBlocksN);
}
const int bm = binfo.bm;
const int bn = binfo.bn;
gemm_w8a8_block<Epilogue>(
binfo,
act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
ascales + bm * (1) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES, // only 1 group in W8A8
wscales + bn * (1) * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
#if 1
out + (bm * BLOCK_M * N) + bn * BLOCK_N,
#else
out + (bm * BLOCK_M * N / 2) + bn * BLOCK_N / 2,
#endif
M, N, K,
epilogueArgs,
alwaysfalse
);
}
};
struct EpilogueGLU {
struct Arguments {};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
gated_fpsum_warp gated_fpsum = apply_glu(fpsum);
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_gated_fpsum_shmem_size, 128) * 128];
unpack_gated_fpsum(gated_fpsum, out + warpId * WARP_M * N / 2, N / 2, shmem[warpId]);
}
};
struct EpilogueSilu {
struct Arguments {};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int warpId = threadIdx.x / WARP_SIZE;
fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
__shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, shmem[warpId]);
}
};
struct EpilogueLiteLA {
__device__ __forceinline__
static half2_t movmatrix(half2_t x) {
asm volatile ("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(*reinterpret_cast<uint32_t *>(&x)) : "r"(*reinterpret_cast<uint32_t *>(&x)));
return x;
}
// __device__ __forceinline__
// static uint4 hmma_fp32(uint4 a, uint2 b, uint4 c) {
// asm volatile(
// "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
// "{%0, %1, %2, %3},"
// "{%4, %5, %6, %7},"
// "{%8, %9},"
// "{%10, %11, %12, %13};\n"
// :
// "=r"(c.x), "=r"(c.y), "=r"(c.z), "=r"(c.w)
// :
// "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
// "r"(b.x), "r"(b.y),
// // "r"(0), "r"(0), "r"(0), "r"(0)
// "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
// );
// return c;
// }
__device__ __forceinline__
static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
for (int i = 0; i < 4; i++) {
k.data[i] = movmatrix(k.data[i]);
v.data[i] = movmatrix(v.data[i]);
}
std::swap(v.data[1], v.data[2]);
return mma_f16xf16_f32(v, k, psum);
}
static constexpr int LITELA_HEAD_DIM = 32;
static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16;
static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16;
static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float);
// out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
__device__ __forceinline__
static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int batch_m) {
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8];
extern __shared__ uint8_t shmem[];
vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem);
static_assert(sizeof(vk_t) == SHMEM_SIZE);
static_assert(WARP_N == BLOCK_N);
assert(binfo.numBlocksN % 3 == 0);
const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
const int batch_id = binfo.bm * BLOCK_M / batch_m;
for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
float *out_vk_current_head = out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM;
for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
*((&shmem_vk[warpId][0][0]) + i) = 0;
}
__syncwarp();
for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v];
for (int j = 0; j < 4; j++) {
k.data[j] = __hmax2(k.data[j], make_half2(0, 0)); // relu
}
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = tile_v * 16 + laneId / 4;
const int col = tile_k * 16 + laneId % 4 * 2;
shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
shmem_vk[warpId][row + 8][col + 0] = attn_sum.data[2];
shmem_vk[warpId][row + 8][col + 1] = attn_sum.data[3];
shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
shmem_vk[warpId][row + 8][col + 8] = attn_sum.data[6];
shmem_vk[warpId][row + 8][col + 9] = attn_sum.data[7];
}
}
for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
packed_f32psum_t attn_sum = { 0 };
for (int i = 0; i < WARP_M_TILES; i++) {
packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
packed_fpsum_t v = {}; // TODO fill to 0
if (laneId < 4) {
v.data[0] = make_half2(1, 1);
v.data[2] = make_half2(1, 1);
}
attn_sum = mma_litela(k, v, attn_sum);
}
const int row = LITELA_HEAD_DIM + laneId / 4;
const int col = tile_k * 16 + laneId % 4 * 2;
if (laneId < 4) {
shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
}
}
__syncthreads();
for (int i = warpId; i < LITELA_HEAD_DIM + 1; i += NUM_WARPS) {
for (int j = laneId; j < LITELA_HEAD_DIM; j += WARP_SIZE) {
float sum = 0;
for (int k = 0; k < NUM_WARPS; k++) {
sum += shmem_vk[k][i][j];
}
reduce_add(&out_vk_current_head[i * LITELA_HEAD_DIM + j], sum);
}
}
__syncthreads();
}
}
struct Arguments {
half_t *out_q;
float *out_vk;
int batch_m;
};
__device__ __forceinline__
void operator()(const BlockInfo binfo, fpsum_warp fpsum, half_t *out, int M, int N, int K, Arguments args) {
const int bm = binfo.bm;
const int bn = binfo.bn;
if (bn < binfo.numBlocksN / 3) {
fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); }); // relu
return EpilogueDefault()(
binfo,
fpsum,
args.out_q + (bm * BLOCK_M * N / 3) + bn * BLOCK_N,
M, N / 3, K, EpilogueDefault::Arguments{});
}
return apply_litela(binfo, fpsum, args.out_vk, args.batch_m);
}
// each thread block mults BlockSize*HEAD_DIM q and (HEAD_DIM+1)*HEAD_DIM vk, in-place writes back to q
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
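// Reference math for the LiteLA path (scalar sketch per head; m indexes tokens):
//   vk[d][d']        = sum_m v[m][d] * relu(k[m][d'])     // accumulated by apply_litela
//   vk[HEAD_DIM][d'] = sum_m relu(k[m][d'])               // extra normalizer row
//   out[m][d]        = sum_d' relu(q[m][d']) * vk[d][d']
//                      / (sum_d' relu(q[m][d']) * vk[HEAD_DIM][d'] + eps)
// The q blocks are ReLU'd and stored by the epilogue above; vk_mul_q_kernel below performs
// the final matrix-vector product and normalization in-place on q.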
struct vk_mul_q_kernel {
// FIXME FIXME FIXME
__device__
void operator()(half_t *q, const float *vk, float eps) {
const int block_id = blockIdx.x;
const int head_id = blockIdx.y;
const int batch_id = blockIdx.z;
const int num_blocks = gridDim.x;
const int num_heads = gridDim.y;
const int block_size = blockDim.x;
half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
// half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
using packed_q = std::array<half_t, 8>;
using packed_vk = std::array<float, 4>;
half_t qblock[LITELA_HEAD_DIM];
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
*reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
}
float outblock[LITELA_HEAD_DIM + 1];
#pragma unroll
for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) {
outblock[j] = 0;
#pragma unroll
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) {
packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i]));
#pragma unroll
for (int k = 0; k < vkpack.size(); k++) {
outblock[j] += (float)qblock[i + k] * vkpack[k];
}
}
}
for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
packed_q opack;
for (int k = 0; k < opack.size(); k++) {
opack[k] = __fdividef(outblock[i + k], outblock[LITELA_HEAD_DIM] + eps);
}
store(reinterpret_cast<packed_q *>(&localq[i]), opack);
}
}
};
};
};
template<typename kernel, typename ...T>
__global__
static void invoke_kernel(T ...args) {
kernel()(args...);
}
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
bool act_unsigned,
std::vector<float> lora_scales // [R / 16]
) {
using GEMM = GEMM_W4A4;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1] * 2;
assert(K == wgt.shape[1] * 2);
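// Epilogue selection below, summarized: bias (when valid) runs first, then LoRA up (when
// valid), then a mid epilogue, then LoRA down (when valid), then the tail:
//   qout + oscales valid -> mid = GeLU, tail = fused re-quantization (EpilogueQuantize)
//                           for the next layer, plus an fp16 store when out is valid;
//   rotary_emb valid     -> tail = QKV projection (RMSNorm on Q/K + RoPE, optional mean pool);
//   otherwise            -> tail = plain fp16 store (EpilogueDefault).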
auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
invoke_kernel<GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
// bias.valid() ? bias.data_ptr<GEMM::packed_wscale_t>() : nullptr,
out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
M, N, K,
args,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
});
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
if (!bias.valid()) {
return launch.template operator()<NextEpilogue>(nextArgs);
}
assert(bias.numel() == N);
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue>;
return launch.template operator()<Epilogue>({
GEMM::EpilogueBias::Arguments{
.bias = bias.data_ptr<GEMM::packed_wscale_t>(),
},
nextArgs
});
};
// auto launch_bias = launch;
auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
assert(lora_up.valid() == lora_act_in.valid());
assert(lora_down.valid() == lora_act_out.valid());
if (!lora_up.valid()) {
assert(!lora_down.valid());
return launch_bias.template operator()<GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
}
const int rank_up = lora_up.shape[1];
assert(lora_up.shape[0] == N);
// assert(lora_up.shape[1] == Lora::LORA_RANK);
assert(lora_act_in.shape[0] == M);
assert(lora_act_in.shape[1] == rank_up);
dispatchVal(rank_up, std::integer_sequence<int, 0, 32, 48, 64, 80, 96>(), [&]<int RANK_UP>() {
using LoraUp = GEMM::Lora<RANK_UP>;
using scale_t = typename LoraUp::scale_t;
scale_t scales;
if constexpr (scales.size() > 0) {
assert(lora_scales.size() >= scales.size());
for (size_t i = 0; i < scales.size(); i++) {
scales[i] = lora_scales[i];
}
}
if (!lora_down.valid()) {
using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<GEMM::packed_fpsum_t>(),
.scales = scales,
},
midArgs,
nextArgs
});
}
const int rank_down = lora_down.shape[1];
assert(rank_down == rank_up);
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank_down);
lora_act_out.zero_();
// dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
using Epilogue = GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue>;
return launch_bias.template operator()<Epilogue>({
typename LoraUp::EpilogueLoraUp::Arguments{
.lora_act = lora_act_in.data_ptr<float>(),
.lora_wgt_up = lora_up.data_ptr<GEMM::packed_fpsum_t>(),
.scales = scales,
},
midArgs,
typename LoraDown::EpilogueLoraDown::Arguments{
.lora_wgt_down = lora_down.data_ptr<GEMM::packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
},
nextArgs
});
// });
});
};
if (qout.valid() && oscales.valid()) {
// dispatchBool(qout_unsigned, [&]<bool USE_UNSIGNED>() {
static constexpr float SHIFT_GELU = 0.171875f;
constexpr bool USE_UNSIGNED = true;
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
auto argsQuantize = typename EpilogueQuantize::Arguments{
.qout = qout.data_ptr<GEMM::packed_act_t>(),
.oscales = oscales.data_ptr<GEMM::packed_ascale_t>(),
.shift_value = SHIFT_GELU,
.smooth_factor = smooth_factor.data_ptr<GEMM::packed_wscale_t>()
};
// TODO: check if gelu is needed
if (out.valid()) {
launch_lora.template operator()<GEMM::EpilogueCombination<GEMM::EpilogueDefault, EpilogueQuantize>, GEMM::EpilogueGelu>({
GEMM::EpilogueDefault::Arguments{},
argsQuantize
}, {});
} else {
launch_lora.template operator()<EpilogueQuantize, GEMM::EpilogueGelu>(argsQuantize, {});
}
// });
} else if (rotary_emb.valid()) {
assert(norm_q.valid());
assert(norm_k.valid());
// assert(isTypeMatch<GEMM::half_t>(rotary_emb.scalar_type()));
assert(rotary_emb.scalar_type() == Tensor::FP32);
assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
launch_lora.template operator()<GEMM::EpilogueQKVProj, GEMM::EpilogueNop>(GEMM::EpilogueQKVProj::Arguments{
.pool_out = poolout.valid() ? poolout.data_ptr<GEMM::half_t>() : nullptr,
.rotary_emb = rotary_emb.data_ptr<float>(),
.rmsnorm_weight_q = norm_q.data_ptr<GEMM::half_t>(),
.rmsnorm_weight_k = norm_k.data_ptr<GEMM::half_t>(),
.epsilon = 1e-6,
}, {});
} else if (out.valid()) {
launch_lora.template operator()<GEMM::EpilogueDefault, GEMM::EpilogueNop>({}, {});
} else {
assert(false);
}
}
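/*
Illustrative usage sketch (not part of the original sources; tensor names are
hypothetical and the call shape is an assumption based on the parameter comments
above). A plain FP16-output call passes default-constructed Tensors for the
unused optional arguments, which takes the `out.valid()` branch with no LoRA,
no bias and no fused quantization:

    // act:     packed INT4 activations [M, K / 2], e.g. from quantize_w4a4_act
    // wgt:     packed INT4 weights     [N, K / 2], e.g. from quantize_w4a4_wgt
    // ascales: [K / 64, M], wscales: [K / 64, N]
    gemm_w4a4(act, wgt, out, {}, ascales, wscales,
              {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
              false, {});
*/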
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth) {
using GEMM = GEMM_W4A4;
// using Lora = GEMM::Lora;
int M = input.numel() / input.shape[-1];
int N = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == N / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
assert(oscales.numel() == M * N / GEMM::WARP_K);
const int rank = lora_down.shape[1];
assert(lora_down.shape[0] == N);
// assert(lora_down.shape[1] == Lora::LORA_RANK);
assert(lora_act_out.shape[0] == M);
assert(lora_act_out.shape[1] == rank);
lora_act_out.zero_();
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
dispatchVal(rank, std::integer_sequence<int, 0, 32, 48, 64, 80, 96>(), [&]<int RANK>() {
using Lora = typename GEMM::Lora<RANK>;
using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel;
auto func = invoke_kernel<kernel, typename kernel::Arguments>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
// log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
typename kernel::Arguments{
.input = input.data_ptr<GEMM::half_t>(),
.smooth_factor = smooth.valid() ? smooth.data_ptr<GEMM::packed_wscale_t>() : nullptr,
.output = output.data_ptr<GEMM::packed_act_t>(),
.oscales = oscales.data_ptr<GEMM::packed_ascale_t>(),
.lora_wgt_down = lora_down.data_ptr<GEMM::packed_fpsum_t>(),
.lora_act = lora_act_out.data_ptr<float>(),
.M = M,
.N = N,
}
);
checkCUDA(cudaGetLastError());
});
}
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
using GEMM = GEMM_W4A4;
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == K / 2);
// assert(oscales.dtype() == Tensor::FP16);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
assert(oscales.numel() == M * K / GEMM::WARP_K);
dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
invoke_kernel<GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K
);
checkCUDA(cudaGetLastError());
}
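/*
Shape sketch for quantize_w4a4_act (illustrative, derived only from the asserts
above): a half-precision input of shape [M, K] is packed into an INT8 tensor of
shape [M, K / 2] (two 4-bit values per byte), with one scale per WARP_K-sized
group per row, i.e. oscales.numel() == M * K / WARP_K. Assuming WARP_K == 64
(consistent with the "[K / 64, M]" scale comments elsewhere in this file), an
input of [1024, 4096] yields an output of [1024, 2048] and 1024 * 4096 / 64 =
65536 scales.
*/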
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
using GEMM = GEMM_W4A4;
int N = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.ndims() == 2);
assert(output.shape[0] == N);
assert(output.shape[1] == K / 2);
assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
// assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == N * K / GEMM::WARP_K);
dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
invoke_kernel<GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_wgt_t>(),
oscales.data_ptr<GEMM::packed_wscale_t>(),
K
);
checkCUDA(cudaGetLastError());
}
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
int M = input.numel() / input.shape[-1];
int K = input.shape[-1];
assert(output.dtype() == Tensor::INT8);
assert(output.numel() / output.shape[-1] == M);
assert(output.shape[-1] == (fuse_glu ? K / 2 : K));
assert(oscales.dtype() == Tensor::FP16);
assert(oscales.numel() == M * 1);
auto launch = [&]<bool FUSE_GLU>() {
using GEMM = GEMM_W8A8;
using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;
assert(kernel::check(M, K));
dim3 grid = kernel::gridSize(M, K);
dim3 block = kernel::blockSize(M, K);
auto func = invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));
func<<<grid, block, kernel::smemSize(M, K)>>>(
input.data_ptr<GEMM::half_t>(),
output.data_ptr<GEMM::packed_act_t>(),
oscales.data_ptr<GEMM::packed_ascale_t>(),
K,
false
);
checkCUDA(cudaGetLastError());
};
if (fuse_glu) {
launch.template operator()<true>();
} else {
launch.template operator()<false>();
}
}
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
)
{
using GEMM = GEMM_W8A8;
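// Note: the epilogue chosen below is EpilogueSilu, so the FP16 output written by
// this launcher appears to be SiLU-activated rather than the raw GEMM result.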
using Epilogue = GEMM::EpilogueSilu;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
Epilogue::Arguments epilogueArgs;
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>><<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
out.data_ptr<GEMM::half_t>(),
M, N, K, epilogueArgs,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
}
void gemm_w8a8_fuse_litela(
Tensor act, // [B, (M), K]
Tensor wgt, // [N, K]
Tensor out_q, // [B, (M), N / 3]
Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
) {
using GEMM = GEMM_W8A8;
using Epilogue = GEMM::EpilogueLiteLA;
int M = act.numel() / act.shape[-1];
int N = wgt.shape[0];
int K = act.shape[-1];
assert(K == wgt.shape[1]);
assert(out_vk.ndims() == 4);
assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
int batch_size = out_vk.shape[0];
int num_heads = out_vk.shape[1];
assert(M % batch_size == 0);
int batch_m = M / batch_size;
Epilogue::Arguments epilogueArgs;
epilogueArgs.batch_m = act.shape[1];
epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
epilogueArgs.out_vk = out_vk.data_ptr<float>();
checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));
auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
const GEMM::packed_act_t *,
const GEMM::packed_wgt_t *,
const GEMM::packed_ascale_t *,
const GEMM::packed_wscale_t *,
GEMM::half_t *,
int, int, int,
Epilogue::Arguments,
bool,
bool>;
checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));
dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);
bool swapBlockMN = M > N * 2;
if (swapBlockMN) {
std::swap(grid.x, grid.y);
}
func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
act.data_ptr<GEMM::packed_act_t>(),
wgt.data_ptr<GEMM::packed_wgt_t>(),
ascales.data_ptr<GEMM::packed_ascale_t>(),
wscales.data_ptr<GEMM::packed_wscale_t>(),
nullptr,
M, N, K, epilogueArgs,
swapBlockMN,
false
);
checkCUDA(cudaGetLastError());
invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
out_q.data_ptr<GEMM::half_t>(),
out_vk.data_ptr<float>(),
1e-6f
);
checkCUDA(cudaGetLastError());
}
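/*
Illustrative shape sketch for gemm_w8a8_fuse_litela (assumptions drawn from the
asserts and launches above): the projection width N covers 3 * num_heads *
LITELA_HEAD_DIM channels. out_q receives the q portion ([B, M/B, N / 3]) and
out_vk accumulates, per head, a (LITELA_HEAD_DIM + 1) x LITELA_HEAD_DIM block,
presumably the V*K^T product plus one extra row for the K sum. out_vk is zeroed
here before the GEMM accumulates into it, and vk_mul_q_kernel then normalizes
out_q in place; the launch assumes the per-batch token count is a multiple of
128 (grid.x = batch_m / 128).
*/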
#pragma once
#include "common.h"
#include "Tensor.h"
void gemm_w4a4(
Tensor act, // packed act [M, K / 2]
Tensor wgt, // packed wgt [N, K / 2]
Tensor out, // linear [M, N]
Tensor qout, // packed act [M, N / 2]
Tensor ascales, // packed as [K / 64, M]
Tensor wscales, // packed ws [K / 64, N]
Tensor oscales, // packed as [N / 64, M]
Tensor poolout, // linear [M / PoolSize, N]
Tensor lora_act_in, // packed lora_act [M, R]
Tensor lora_up, // packed lora_wgt [N, R]
Tensor lora_down, // packed lora_wgt [N, R]
Tensor lora_act_out, // packed lora_act [M, R]
Tensor norm_q, // linear [HEAD_DIM]
Tensor norm_k, // linear [HEAD_DIM]
Tensor rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
Tensor bias, // packed ws [N]
Tensor smooth_factor, // packed ws [N], for quantization of the next layer
bool act_unsigned,
std::vector<float> lora_scales // [R / 16]
);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {});
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
void gemm_w8a8(Tensor act, // [M, K]
Tensor wgt, // [N, K]
Tensor out, // [M, N]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
);
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
void gemm_w8a8_fuse_litela(
Tensor act, // [B, (M), K]
Tensor wgt, // [N, K]
Tensor out_q, // [B, (M), N / 3]
Tensor out_vk, // [B, num_heads, head_dim + 1, head_dim]
Tensor ascales, // [1, M]
Tensor wscales // [1, N]
);
#include "common.h"
#include "Tensor.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
using spdlog::fmt_lib::format;
Tensor gemm_w8a8_fp16(Tensor input, // INT8
Tensor weight, // INT8
Tensor out, // FP16
half alpha,
half beta // FP16
) {
auto N = weight.size(0);
auto K = input.size(-1);
auto M = input.numel() / K;
assert(weight.size(1) == K);
spdlog::debug("gemm_w8a8_fp16: M={} K={} N={}", M, K, N);
using ElementOutput = cutlass::half_t;
using ElementAccumulator = int32_t;
using ElementComputeEpilogue = cutlass::half_t;
using ElementInputA = int8_t; // <- data type of elements in input matrix A
using ElementInputB = int8_t; // <- data type of elements in input matrix B
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// #if CUDA_ARCH >= 800
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementComputeEpilogue>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
// #elif CUDA_ARCH >= 750
// using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
// cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
// ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
// using Gemm = cutlass::gemm::device::Gemm<
// int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
// ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
// cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
// DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
// DefaultGemmCfg::InstructionShape,
// cutlass::epilogue::thread::LinearCombination<
// ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
// ElementAccumulator, ElementComputeEpilogue>>;
// #elif CUDA_ARCH >= 700
// using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
// cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
// ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
// using Gemm = cutlass::gemm::device::Gemm<
// int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
// ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
// cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
// DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
// DefaultGemmCfg::InstructionShape,
// cutlass::epilogue::thread::LinearCombination<
// ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
// #else
// #error "Unsupported cuda arch"
// #endif
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
// auto out = bias.to(device).view({1, -1}).repeat({M, 1});
if (!out.valid()) {
auto out_shape = input.shape;
out_shape[-1] = N;
out = Tensor::empty(out_shape, Tensor::FP16, input.device());
}
// FIXME: check contiguous of input if dims >= 3
assert(input.stride(-1) == 1);
// assert(input.is_contiguous());
assert(weight.is_contiguous());
assert(out.dtype() == Tensor::FP16);
assert(out.shape[-1] == N);
assert(out.numel() / out.shape[-1] == M);
assert(out.stride(-1) == 1);
// FIXME: check contiguous of output if dims >= 3
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{ElementOutput(alpha), ElementOutput(beta)},
1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
// cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
BufferCUDA workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.getPtr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
return out;
}
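/*
Note on the CUTLASS arguments above (a reading of the code, not new behaviour):
out_ref is passed as both the C and D operands, so the LinearCombination
epilogue computes, per element,
    out = alpha * (int8 A x int8 B, accumulated in int32) + beta * out
with alpha and beta converted to cutlass::half_t. A zero beta gives a plain
scaled GEMM, while a non-zero beta accumulates into whatever `out` already
holds (the commented-out broadcasted-bias trick above relies on this).
*/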
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor gemm_w8a8_fp16(Tensor input, // INT8
Tensor weight, // INT8
Tensor out,
half scale,
half bias
);
#include "layernorm_kernels_impl.cuh"
#include "dispatch_utils.h"
void rms_norm(Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &weight, // [hidden_size]
float epsilon,
bool use_quant) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
if (use_quant) {
vllm::rms_norm_kernel<scalar_t, int8_t, true><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
} else {
vllm::rms_norm_kernel<scalar_t, scalar_t, false><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
}
});
}
void layernorm_general(Tensor out, Tensor input, Tensor weight, Tensor bias, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 256));
block.x = 32 * ((block.x + 31) / 32);
size_t size_shmem = input.scalar_size() * hidden_size;
const cudaStream_t stream = getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] {
using T = typename packed_as<scalar_t, 2>::type;
vllm::generalLayerNorm<T, half, true><<<grid, block, size_shmem, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
weight.valid() ? reinterpret_cast<T*>(weight.data_ptr<scalar_t>()) : nullptr,
bias.valid() ? reinterpret_cast<T*>(bias.data_ptr<scalar_t>()) : nullptr,
reinterpret_cast<T*>(out.data_ptr<scalar_t>()),
epsilon, num_tokens, hidden_size, nullptr, nullptr, nullptr, true
);
});
}
void rms_norm_general(Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &weight, // [hidden_size]
Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
block.x = 32 * ((block.x + 31) / 32);
const cudaStream_t stream = getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm", [&] {
using T = scalar_t;
if (use_per_token_quant) {
// per-token
vllm::generalLayerNorm<T, half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<half>(),
out.data_ptr<int8_t>(), false
);
// input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
// normed_output_quant, use_shmem
// out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
// weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
} else {
// per-tensor
vllm::generalLayerNorm<T, half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, scaling.data_ptr<half>(), nullptr,
out.data_ptr<int8_t>(), false
);
}
});
}
void rms_norm_general_fuse_sum(Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &weight, // [hidden_size]
Tensor &input_sum, // [tokens] or [1]
Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
block.x = 32 * ((block.x + 31) / 32);
const cudaStream_t stream = getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalLayerNorm_fuse_sum", [&] {
using T = scalar_t;
if (use_per_token_quant) {
// per-token
vllm::generalLayerNorm_fuse_sum<T, half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, input_sum.data_ptr<half>(), nullptr, scaling.data_ptr<half>(),
out.data_ptr<int8_t>(), false
);
// input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
// normed_output_quant, use_shmem
// out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
// weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
} else {
// per-tensor
// Raise an error here: per-tensor input_sum is not implemented.
assert(false);
vllm::generalLayerNorm_fuse_sum<T, half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<half>(), nullptr,
out.data_ptr<int8_t>(), false
);
}
});
}
void invoke_dequant_add_residual_rms_norm_quant(
Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &residual, // [..., hidden_size]
Tensor &gamma, // [hidden_size]
half scale,
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel",
[&] {
vllm::dequant_add_residual_rms_norm_quant_kernel<scalar_t, half, false>
<<<grid, block, 0, stream>>>(
input.data_ptr<int32_t>(), residual.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(), gamma.data_ptr<scalar_t>(), epsilon,
scale, num_tokens, hidden_size);
});
}
void invoke_dequant_add_residual_rms_norm_quant(
Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &residual, // [..., hidden_size]
Tensor &gamma, // [hidden_size]
Tensor &scale, // [num_tokens]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
residual.scalar_type(), "dequant_add_residual_rms_norm_quant_kernel",
[&] {
vllm::dequant_add_residual_rms_norm_quant_kernel<scalar_t, half*, true>
<<<grid, block, 0, stream>>>(
input.data_ptr<int32_t>(), residual.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(), gamma.data_ptr<scalar_t>(), epsilon,
scale.data_ptr<half>(), num_tokens, hidden_size);
});
}
#pragma once
#include "common.h"
#include "Tensor.h"
#include <cuda_fp16.h>
void rms_norm(Tensor &out, // [num_tokens, hidden_size]
Tensor &input, // [num_tokens, hidden_size]
Tensor &weight, // [hidden_size]
float epsilon, bool use_quant);
void layernorm_general(Tensor out, Tensor input, Tensor weight, Tensor bias, float epsilon);
void rms_norm_general(Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &weight, // [hidden_size]
Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
void rms_norm_general_fuse_sum(Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &weight, // [hidden_size]
Tensor &input_sum, // [tokens] or [1]
Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
void invoke_dequant_add_residual_rms_norm_quant(
Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &residual, // [..., hidden_size]
Tensor &gamma, // [hidden_size]
half scale, float epsilon);
void invoke_dequant_add_residual_rms_norm_quant(
Tensor &out, // [..., hidden_size]
Tensor &input, // [..., hidden_size]
Tensor &residual, // [..., hidden_size]
Tensor &gamma, // [hidden_size]
Tensor &scale, // [num_tokens]
float epsilon);
#include <cuda_bf16.h>
#define ENABLE_BF16 1
#include "utils.cuh"
#include "reduction_utils.cuh"
namespace vllm {
// from TRTLLM
template <typename Tf, typename T>
__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, const T* gamma, const T* beta, int i)
{
Tf ret = (val - s_mean) * s_variance;
if (gamma != nullptr) {
ret = ret * cuda_cast<Tf>(gamma[i]);
}
if (beta != nullptr) {
ret = ret + cuda_cast<Tf>(beta[i]);
}
return ret;
}
// from TRTLLM
/* Computes the layernorm https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
* normed_output <- ( (input - E[input]) / Sqrt(Var[input] + eps) ) * gamma + beta
* input is [tokens, hidden_dim]. Mean and Variance are per-row (i.e. per-token)
*
* One CTA handles one row.
*
* with USE_DIFF_OF_SQUARES set to false:
* First pass (loop) computes the mean.
* Second pass computes the variance via Var[x] = E[(x - E[x])²].
* Third pass computes and writes normed_output
*
* with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
* First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
* Second pass computes and writes normed_output
*
* use_shmem controls if we cache input values into shared memory
*
* Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
* amax per row. A final pass scales to int8 accordingly, and writes output to
* normed_output_quant.
*/
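/*
Small worked example of the dynamic per-token scaling described above
(illustrative only): if a row's post-layernorm amax is 2.5, the kernel uses
dynamic_per_token_scale = 127 / 2.5 = 50.8, converts x * 50.8 to int8, and
stores scale_orig_quant_per_token = 2.5 / 127 so a consumer can dequantize with
x ~= q * (2.5 / 127). The USE_DIFF_OF_SQUARES variant instead folds the
variance into the first pass using Var[x] = E[x^2] - E[x]^2, saving one pass
over the row at some cost in accuracy when E[x^2] and E[x]^2 are close.
*/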
template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
{
constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem);
__shared__ float s_mean;
__shared__ float s_variance;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
float mean = 0.0f;
float variance = 0.0f;
float local_sum = 0.0f;
float local_var_sum = 0.0f;
const int n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = input[bidx * n_elems + i];
if (use_shmem)
{
shmem[i] = val;
}
const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES)
{
local_var_sum += cuda_sum<float>(val_f * val_f);
}
}
if (USE_DIFF_OF_SQUARES)
{
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];
}
else
{
mean = blockReduceSum(local_sum);
}
if (threadIdx.x == 0)
{
mean = mean / hidden_dim;
s_mean = mean;
if (USE_DIFF_OF_SQUARES)
{
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
}
__syncthreads();
if (!USE_DIFF_OF_SQUARES)
{
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0)
{
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();
}
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
const T val = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));
if (with_per_token_scaling)
{
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
if (use_shmem)
{
shmem[i] = val;
}
}
else if (with_per_tensor_scaling)
{
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
}
else
{
normed_output[index] = val;
}
}
if (with_per_token_scaling)
{
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
if (!use_shmem)
{
val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
}
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0)
{
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
}
}
}
template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
int tokens, int hidden_dim, scale_type* input_sum, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
{
constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem);
__shared__ float s_mean;
__shared__ float s_variance;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
float mean = 0.0f;
float variance = 0.0f;
float local_sum = 0.0f;
float local_var_sum = 0.0f;
const int n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = input[bidx * n_elems + i];
if (use_shmem)
{
shmem[i] = val;
}
const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES)
{
local_var_sum += cuda_sum<float>(val_f * val_f);
}
}
if (USE_DIFF_OF_SQUARES)
{
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];
}
else
{
mean = blockReduceSum(local_sum);
}
if (threadIdx.x == 0)
{
mean = mean / hidden_dim;
s_mean = mean;
if (USE_DIFF_OF_SQUARES)
{
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
}
__syncthreads();
if (!USE_DIFF_OF_SQUARES)
{
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0)
{
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();
}
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f;
T_scalar sum = 0.0f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
const T val = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));
if (with_per_token_scaling)
{
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
sum += cuda_sum<float>(val);
if (use_shmem)
{
shmem[i] = val;
}
}
else if (with_per_tensor_scaling)
{
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
}
else
{
normed_output[index] = val;
}
}
if (with_per_token_scaling)
{
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
float sum_f = blockAllReduceSum(cuda_cast<float>(sum));
const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
if (!use_shmem)
{
val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
}
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0)
{
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
input_sum[bidx] = sum_f;
}
}
}
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename out_type, bool use_quant>
__global__ void
rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size]
const scalar_t *__restrict__ input, // [..., hidden_size]
const scalar_t *__restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
if constexpr (use_quant) {
out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(
((float)(x * s_variance)) * (float)(weight[idx]));
} else {
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
}
template <typename T, typename scale_type, bool use_per_token_dequant>
__global__ void dequant_add_residual_rms_norm_quant_kernel(
const int32_t *__restrict__ input, T *__restrict__ residual,
int8_t *__restrict__ output, const T *__restrict__ gamma,
const float layernorm_eps, const scale_type scale, int num_tokens, int hidden_size) {
// layernorm module in the T5 style No bias and no subtraction of mean.
const int tid = threadIdx.x;
__shared__ float s_variance;
float variance = 0.0f;
float local_var_sum = 0.0f;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float diff = 0.0f;
if constexpr (use_per_token_dequant) {
diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale[blockIdx.x])) +
(float)residual[blockIdx.x * hidden_size + i]);
} else {
diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale)) +
(float)residual[blockIdx.x * hidden_size + i]);
}
residual[blockIdx.x * hidden_size + i] = (T)diff;
local_var_sum += diff * diff;
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (float)hidden_size + layernorm_eps);
}
__syncthreads();
for (int i = tid; i < hidden_size; i += blockDim.x) {
output[blockIdx.x * hidden_size + i] = float_to_int8_rn(
(((float)(residual[blockIdx.x * hidden_size + i])) * s_variance) *
(float)(gamma[i]));
}
}
} // namespace vllm
#include "misc_kernels_impl.cuh"
#include "misc_kernels.h"
#include "dispatch_utils.h"
Tensor add(Tensor a, Tensor b) {
assert(a.shape.dataExtent == b.shape.dataExtent);
assert(a.dtype() == b.dtype());
assert(a.is_contiguous());
assert(b.is_contiguous());
int threadsPerBlock = 1024;
int blocksPerGrid = (a.numel() + threadsPerBlock - 1) / threadsPerBlock;
auto stream = getCurrentCUDAStream();
Tensor out = Tensor::empty_like(a);
dispatch(out.scalar_type(), [&]<typename scalar_t>() {
add_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
a.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), out.numel());
});
return out;
}
void mul_add(Tensor x, Tensor scale, Tensor bias) {
// assert(scale.shape.data == bias.shape.data);
// FIXME FIXME
assert(x.numel() % scale.numel() == 0);
assert(x.numel() % bias.numel() == 0);
assert(x.dtype() == scale.dtype());
assert(x.dtype() == bias.dtype());
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);
assert(x.numel() % unroll == 0);
assert(scale.numel() % unroll == 0);
assert(bias.numel() % unroll == 0);
int threadsPerBlock = 1024;
int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
mul_add_kernel<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), x.numel(), scale.numel(), bias.numel());
});
}
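/*
Note on the asserts above (illustrative): with unroll = 8 and FP16/BF16 tensors,
each thread handles 8 scalars = 16 bytes, i.e. one 128-bit vector access, which
is why all three pointers must be 16-byte aligned and all element counts must be
divisible by 8. scale and bias are smaller tensors whose sizes must divide
x.numel(); the kernel presumably broadcasts them by wrap-around indexing.
*/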
Tensor embedding(Tensor input_id, Tensor lookup) {
assert(input_id.dtype() == Tensor::INT32);
assert(lookup.ndims() == 2);
auto shapeOut = input_id.shape;
shapeOut.dataExtent.push_back(lookup.shape[-1]);
auto stream = getCurrentCUDAStream();
Tensor out = Tensor::empty(shapeOut, lookup.scalar_type(), input_id.device());
dispatch(out.scalar_type(), [&]<typename scalar_t>() {
EmbeddingKernel<<<input_id.numel(), std::min(lookup.shape[-1], 1024), 0, stream>>>(
input_id.data_ptr<int32_t>(), out.data_ptr<scalar_t>(), lookup.data_ptr<scalar_t>(), lookup.shape[-1]);
});
return out;
}
Tensor argmax_sample(Tensor logits) {
assert(logits.ndims() == 2);
auto stream = getCurrentCUDAStream();
Tensor out = Tensor::empty({logits.shape[0]}, Tensor::INT32, logits.device());
dispatch(logits.scalar_type(), [&]<typename scalar_t>() {
argmax_sample_kernel<<<logits.shape[0], std::min(logits.shape[1], 1024), 0, stream>>>(
logits.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), logits.shape[1]
);
});
return out;
}
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v) {
// FIXME FIXME
// assert(qkv.shape[0] == q.shape[0]);
// assert(qkv.shape[0] == k.shape[0]);
// assert(qkv.shape[0] == v.shape[0]);
auto stream = getCurrentCUDAStream();
int dim_q = q.shape[-1] * q.shape[-2];
int dim_k = k.shape[-1] * k.shape[-2];
int dim_v = v.shape[-1] * v.shape[-2];
assert(dim_k == dim_v);
assert(dim_q + dim_k + dim_v == qkv.shape[-1]);
int num_tokens = qkv.numel() / qkv.shape[-1];
dispatch(qkv.scalar_type(), [&]<typename scalar_t>() {
splitqkv_kernel<<<num_tokens, std::min(qkv.shape[-1], 1024), 0, stream>>>(
qkv.data_ptr<scalar_t>(),
q.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(),
v.data_ptr<scalar_t>(),
dim_q,
dim_k
);
});
}
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input) {
assert(input.shape[-1] % N == 0);
int threadsPerBlock = 1024;
int blocksPerGrid = (input.numel() + threadsPerBlock - 1) / threadsPerBlock;
auto stream = getCurrentCUDAStream();
auto shapeOut = input.shape;
shapeOut[-1] /= N;
std::array<Tensor, N> out;
for (int k = 0; k < N; k++) {
out[k] = Tensor::empty(shapeOut, input.scalar_type(), input.device());
}
dispatch(input.scalar_type(), [&]<typename scalar_t>() {
std::array<scalar_t *, N> outPtr;
for (int k = 0; k < N; k++) {
outPtr[k] = out[k].template data_ptr<scalar_t>();
}
split_mod_kernel<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
input.data_ptr<scalar_t>(),
outPtr, input.numel());
});
return out;
}
Tensor quant_static(Tensor x, float scale) {
Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
int threadsPerBlock = 1024;
int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
quant_kernel_static<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
});
return out;
}
Tensor quant_static_fuse_gelu(Tensor x, float scale) {
Tensor out = Tensor::empty(x.shape, Tensor::INT8, x.device());
constexpr int unroll = 8;
assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
int threadsPerBlock = 1024;
int blocksPerGrid = (x.numel() + threadsPerBlock * unroll - 1) / (threadsPerBlock * unroll);
auto stream = getCurrentCUDAStream();
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
quant_kernel_static_fuse_gelu<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
x.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), (scalar_t)scale, x.numel());
});
return out;
}
Tensor topk(Tensor x, int k) {
constexpr int MAXK = 64 + 4;
const int N = x.shape[-1];
const int batch = x.numel() / N;
assert(k <= N);
assert(k <= MAXK);
auto outShape = x.shape;
outShape[-1] = k;
outShape.dataStride.clear();
Tensor out = Tensor::empty(outShape, Tensor::INT32, x.device());
auto stream = getCurrentCUDAStream();
dispatchVal(k, std::make_integer_sequence<int, MAXK + 1>(), [&]<int K>() {
if constexpr (K == 0) {
assert(false);
return;
}
if constexpr (K > 0) {
dispatch(x.scalar_type(), [&]<typename scalar_t>() {
topk_kernel<scalar_t, K><<<ceilDiv(batch, 32), 32, 0, stream>>>(
x.data_ptr<scalar_t>(),
out.data_ptr<int>(),
N, x.stride(-2), batch
);
checkCUDA(cudaGetLastError());
});
}
});
return out;
}
template std::array<Tensor, 2> split_mod<2>(Tensor input);
template std::array<Tensor, 3> split_mod<3>(Tensor input);
template std::array<Tensor, 4> split_mod<4>(Tensor input);
template std::array<Tensor, 5> split_mod<5>(Tensor input);
template std::array<Tensor, 6> split_mod<6>(Tensor input);
#pragma once
#include "common.h"
#include "Tensor.h"
Tensor add(Tensor a, Tensor b);
void mul_add(Tensor x, Tensor scale, Tensor bias);
Tensor embedding(Tensor input_id, Tensor lookup);
Tensor argmax_sample(Tensor logits);
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
Tensor quant_static(Tensor x, float scale);
Tensor quant_static_fuse_gelu(Tensor x, float scale);
Tensor topk(Tensor x, int k);
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input);