Unverified Commit 6512937d authored by HandH1998's avatar HandH1998 Committed by GitHub
Browse files

Support W4A8 quantization for vllm (#5218)

parent c0644cf9
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.409
- name: "exact_match,flexible-extract"
value: 0.406
limit: 1000
num_fewshot: 5
...@@ -7,3 +7,4 @@ Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml ...@@ -7,3 +7,4 @@ Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base.yaml Minitron-4B-Base.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml
...@@ -170,6 +170,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -170,6 +170,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
......
...@@ -115,6 +115,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, ...@@ -115,6 +115,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias); c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
...@@ -25,6 +25,12 @@ ...@@ -25,6 +25,12 @@
#include <iostream> #include <iostream>
#include "common/base.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
#endif
template <typename T> template <typename T>
inline std::string str(T x) { inline std::string str(T x) {
return std::to_string(x); return std::to_string(x);
...@@ -32,23 +38,9 @@ inline std::string str(T x) { ...@@ -32,23 +38,9 @@ inline std::string str(T x) {
namespace marlin_dense { namespace marlin_dense {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
using I4 = Vec<int, 4>; using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is // Matrix fragments for tensor core instructions; their precise layout is
// documented here: // documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
...@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>; ...@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales using FragS = Vec<half2, 1>; // quantization scales
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation. // output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
...@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { ...@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
template <const int threads, // number of threads in a threadblock template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
......
/*
* Adapted from
* https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu
* https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp
* Modified by HandH1998
* Copyright (C) 2024 HandH1998
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "../dense/common/base.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h"
#endif
template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type
using FragA = Vec<uint32_t, 2>;
using FragB = Vec<uint32_t, 1>;
using FragC = Vec<int, 4>;
using FragS_GROUP = Vec<half2, 1>; // weight per-group quantization scales
using FragS_CHANNEL =
Vec<float, 2>; // weight per-channel quantization scales or activaton
// per-token quantization scales
// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however,
// cp.async.ca can support BYTES = 4, 8, 16;
// as s_tok's shape is equal to prob_m, we need set s_tok to float type,
// and cp_size = 1 float, i.e., 4 BYTES
// Asynchronous global->shared copy for activation quantizaton scales s_tok
__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 4;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.ca.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// m16n8k16 tensor core mma instruction with int8 inputs and int32
// output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
int* c = reinterpret_cast<int*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in int8 tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
: "r"(smem));
}
inline __device__ half2 float2_to_half2(float2 f) {
uint32_t res;
// NOTE(HandH1998): h0,h1 should be uint16_t, not half
uint16_t h0, h1;
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y));
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1));
return reinterpret_cast<half2&>(res);
}
inline __device__ float int32_to_float(int h) {
float res;
asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h));
return res;
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
// for weight per channel dequant.
__device__ inline FragB dequant_per_channel(int q) {
static constexpr int MASK = 0xf0f0f0f0;
FragB frag_b;
frag_b[0] = (q & MASK);
return frag_b;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
// for weight per group dequant.
__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
static constexpr uint32_t LO = 0x000f000f;
static constexpr uint32_t HI = 0x00f000f0;
static constexpr uint32_t EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
static constexpr uint32_t SUB = 0x64086408;
static constexpr uint32_t MUL = 0x2c002c00;
static constexpr uint32_t ADD = 0xd480d480;
*reinterpret_cast<half2*>(&t0) = __hsub2(
*reinterpret_cast<half2*>(&t0), *reinterpret_cast<const half2*>(&SUB));
*reinterpret_cast<half2*>(&t1) = __hfma2(
*reinterpret_cast<half2*>(&t1), *reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
uint16_t s = reinterpret_cast<uint16_t*>(&frag_s)[i];
uint32_t double_s;
// pack 2xfp16 to half2
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s));
// dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4
// half, respectively)
static constexpr uint32_t MAGIC_NUM = 0x64806480;
*reinterpret_cast<half2*>(&t0) = __hfma2(
*reinterpret_cast<half2*>(&t0), *reinterpret_cast<half2*>(&double_s),
*reinterpret_cast<const half2*>(&MAGIC_NUM));
*reinterpret_cast<half2*>(&t1) = __hfma2(
*reinterpret_cast<half2*>(&t1), *reinterpret_cast<half2*>(&double_s),
*reinterpret_cast<const half2*>(&MAGIC_NUM));
// take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4
// int8 into 1 uint32
FragB frag_b;
uint32_t uint8s;
static constexpr uint32_t MASK_0246 = 0x6420;
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(uint8s)
: "r"(t0), "r"(t1), "n"(MASK_0246));
frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK);
return frag_b;
}
template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin(
const int4* __restrict__ A, // int8 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // int32 global_reduce buffer of shape
// (max_par*16*4)xn, as int8 tensor core's output is
// int32 dtype
int4* __restrict__ D, // fp16 output buffer of shape mxn
const float* __restrict__ s_tok, // fp32 activation per-token quantization
// scales of shape mx1
const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
// scales of shape 1xn
const int4* __restrict__ s_group, // fp16 weight per-group quantization
// scales of shape (k/groupsize)xn, when
// group_blocks=-1, it should be nullptr
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
if (prob_m > 16 * thread_m_blocks) {
parallel = prob_m / (16 * thread_m_blocks);
prob_m = 16 * thread_m_blocks;
}
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts in
// the middle of group.
if constexpr (group_blocks != -1)
iters = (group_blocks / thread_k_blocks) *
ceildiv(iters, (group_blocks / thread_k_blocks));
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count =
0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4;
D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
}
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) slice_idx--;
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 16;
C += 16 * thread_m_blocks * prob_n / 4;
D += 16 * thread_m_blocks * prob_n / 8;
s_tok += 16 * thread_m_blocks;
locks += n_tiles;
slice_col = 0;
}
};
init_slice();
int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory
// We typically use `constexpr` to indicate that this value is a compile-time
// constant
constexpr int a_sh_stride =
16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory
constexpr int a_gl_rd_delta_o =
16 * thread_k_blocks /
16; // delta between subsequent A tiles in global memory
int a_gl_rd_delta_i =
a_gl_stride *
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
constexpr int a_sh_wr_delta =
a_sh_stride *
(threads / a_gl_rd_delta_o); // between shared memory writes
constexpr int a_sh_rd_delta_o =
1 * ((threads / 32) /
(thread_n_blocks / 4)); // between shared memory tile reads
constexpr int a_sh_rd_delta_i =
a_sh_stride * 16; // within a shared memory tile
constexpr int a_sh_stage =
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
constexpr int a_sh_wr_iters =
ceildiv(a_sh_stage,
a_sh_wr_delta); // number of shared write iterations for a tile
int b_gl_stride = 16 * prob_n / 32;
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
constexpr int b_sh_wr_delta = threads;
constexpr int b_sh_rd_delta = threads;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
constexpr int s_tok_sh_stride = 16 * thread_m_blocks;
constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4;
int s_group_gl_stride = prob_n / 8;
constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_group_sh_stage = s_group_sh_stride;
int s_group_gl_rd_delta = s_group_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
// NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix
int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16);
a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd =
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x;
int b_sh_rd = threadIdx.x;
int s_tok_gl_rd = threadIdx.x;
// NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
// s_tok's size is not fixed, we can not shuffle before inference we shuffle
// it when fetching s_tok from global memory to shared memory, that's why
// s_tok_sh_wr is like this
int s_tok_sh_wr =
(threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8;
int s_tok_sh_rd = (threadIdx.x % 32) / 4;
bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
int s_ch_sh_wr = threadIdx.x;
int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
2 * ((threadIdx.x % 32) % 4);
bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd;
bool s_group_sh_wr_pred;
if constexpr (group_blocks != -1) {
s_group_gl_rd =
s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_group_sh_stride * slice_col + threadIdx.x;
s_group_sh_wr = threadIdx.x;
// NOTE(HandH1998): s_group_sh_rd is related to mma output C
s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride;
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
// NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages *
// a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage)
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_s_tok = sh_b + (stages * b_sh_stage);
int4* sh_s_ch = sh_s_tok + s_tok_sh_stride;
int4* sh_s_group = sh_s_ch + s_ch_sh_stride;
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2];
FragC frag_c[thread_m_blocks][4][2];
FragS_GROUP frag_s_group[2][4];
FragS_CHANNEL frag_s_tok[thread_m_blocks];
FragS_CHANNEL frag_s_ch[2][4];
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<int*>(frag_c)[i] = 0;
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if constexpr (group_blocks != -1) {
if (pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe;
if (s_group_sh_wr_pred)
cp_async4(&sh_s_group_stage[s_group_sh_wr],
&s_group[s_group_gl_rd]);
s_group_gl_rd += s_group_gl_rd_delta;
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<stages - 2>();
__syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
// It may seem inefficient that we reload the groups for every sub-tile;
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if constexpr (group_blocks != -1) {
int4* sh_s_group_stage =
sh_s_group +
s_group_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s_group[k % 2])[0] =
sh_s_group_stage[s_group_sh_rd];
}
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
int b_quant = frag_b_quant[k % 2][j];
// int b_quant_shift = b_quant << 4;
FragB frag_b0, frag_b1;
// If there are no groups, we can just scale the final output once and can
// avoid doing so for each weight.
if constexpr (group_blocks != -1) {
int b_quant_shift = b_quant >> 8;
frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0);
frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1);
} else {
int b_quant_shift = b_quant << 4;
frag_b0 = dequant_per_channel(b_quant);
frag_b1 = dequant_per_channel(b_quant_shift);
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride;
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
constexpr int red_sh_delta = b_sh_stride;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
(threadIdx.x % b_sh_stride);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
int* c_rd =
reinterpret_cast<int*>(&sh[red_sh_delta * j + red_sh_rd]);
int* c_wr = reinterpret_cast<int*>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
int* c_rd =
reinterpret_cast<int*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
// global_reduce works on INT32 elements, which are the results of INT8 GEMM.
// This is why we need another INT32 maxtrix `C` to reduce instead of the
// original half matrix `D`.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 4;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 8 * (active_threads / 32);
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
c_gl_wr += (4 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads * 2;
int c_sh_wr = 2 * threadIdx.x;
int row = (threadIdx.x % 32) / 4;
if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i + 1],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2) + 1],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
}
cp_async_fence();
cp_async_wait<0>();
}
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta];
int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1];
#pragma unroll
for (int j = 0; j < 4; j++) {
reinterpret_cast<int*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
reinterpret_cast<int*>(&d_red1)[j];
}
#pragma unroll
for (int j = 0; j < 4; j++) {
reinterpret_cast<int*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] +=
reinterpret_cast<int*>(&d_red2)[j];
}
}
if (!last) {
int4 d1, d2;
#pragma unroll
for (int j = 0; j < 4; j++) {
reinterpret_cast<int*>(&d1)[j] = reinterpret_cast<int*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)];
}
#pragma unroll
for (int j = 0; j < 4; j++) {
reinterpret_cast<int*>(&d2)[j] = reinterpret_cast<int*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)];
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
d1;
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) +
1] = d2;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
int d_gl_stride = prob_n / 8;
constexpr int d_sh_stride = 2 * thread_n_blocks + 1;
int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int d_sh_rd_delta =
d_sh_stride * (threads / (2 * thread_n_blocks));
int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
d_gl_wr += (2 * thread_n_blocks) * slice_col;
int d_sh_wr =
(4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
d_sh_wr += 32 * (threadIdx.x / 32);
int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
int d_gl_wr_end = d_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) {
float2 deq_res;
deq_res.x = int32_to_float(c0) * w_s[0] * a_s;
deq_res.y = int32_to_float(c1) * w_s[1] * a_s;
((half2*)sh)[idx] = float2_to_half2(deq_res);
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
int wr = d_sh_wr + 8 * j;
write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s_tok[i][0],
frag_s_ch[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s_tok[i][1],
frag_s_ch[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s_tok[i][0],
frag_s_ch[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s_tok[i][1],
frag_s_ch[j / 2][2 * (j % 2) + 1]);
}
d_sh_wr += 16 * (4 * d_sh_stride);
}
}
__syncthreads();
#pragma unroll
for (int i = 0;
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (d_gl_wr < d_gl_wr_end) {
D[d_gl_wr] = sh[d_sh_rd];
d_gl_wr += d_gl_wr_delta;
d_sh_rd += d_sh_rd_delta;
}
}
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
zero_accums();
wait_for_stage();
fetch_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
};
start_pipes();
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines have
// even length meaning that the next iteration will always start at index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) break;
}
a_gl_rd += a_gl_rd_delta_o * stages;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if (last) {
if (s_tok_sh_wr_pred) {
cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]);
}
if (s_ch_sh_wr_pred) {
cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]);
}
cp_async_fence();
}
thread_block_reduce();
if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
frag_s_tok[i][0] =
*reinterpret_cast<float*>(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]);
frag_s_tok[i][1] = *reinterpret_cast<float*>(
&sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]);
}
reinterpret_cast<int4*>(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1];
reinterpret_cast<int4*>(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8];
reinterpret_cast<int4*>(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9];
}
}
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x;
s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
start_pipes();
}
}
}
}
#else
template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin(
const int4* __restrict__ A, // int8 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // int32 global_reduce buffer of shape
// (max_par*16*4)xn, as int8 tensor core's output is
// int32 dtype
int4* __restrict__ D, // fp16 output buffer of shape mxn
const float* __restrict__ s_tok, // fp32 activation per-token quantization
// scales of shape mx1
const int4* __restrict__ s_ch, // fp32 weight per-channel quantization
// scales of shape 1xn
const int4* __restrict__ s_group, // fp16 weight per-group quantization
// scales of shape (k/groupsize)xn, when
// group_blocks=-1, it should be nullptr
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
#endif
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int pack_factor_4bit =
8; // We have 8 4-bit vals inside a 32 bit
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
GROUP_BLOCKS, NUM_THREADS) \
else if (thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \
prob_m, prob_n, prob_k, locks); \
}
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
};
bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
int prob_k) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
return true;
}
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
return th_config;
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
return th_config;
}
}
}
return thread_config_t{-1, -1, -1};
}
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D,
void* s_tok, void* s_ch, void* s_group, int prob_m,
int prob_n, int prob_k, void* workspace,
int groupsize = -1, int dev = 0, cudaStream_t stream = 0,
int thread_k = -1, int thread_n = -1, int sms = -1,
int max_par = 16) {
int tot_m = prob_m;
int tot_m_blocks = ceildiv(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
if (sms == -1)
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
// Set thread config
thread_config_t th_config;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
} else {
// Auto config
th_config = determine_thread_config(prob_m, prob_n, prob_k);
}
if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
throw std::runtime_error(
"Invalid thread config: thread_k = " + str(th_config.thread_k) +
", thread_n = " + str(th_config.thread_n) +
", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
}
int num_threads = th_config.num_threads;
thread_k = th_config.thread_k;
thread_n = th_config.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
int blocks = sms;
if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
return;
}
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
if (group_blocks != -1) {
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* D_ptr = (int4*)D;
const float* s_tok_ptr = (const float*)s_tok;
const int4* s_ch_ptr = (const int4*)s_ch;
const int4* s_group_ptr = (const int4*)s_group;
int* locks = (int*)workspace;
for (int i = 0; i < tot_m_blocks; i += 4) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
if (thread_m_blocks > 4) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_m_blocks - pad) / 64;
if (par > max_par) par = max_par;
prob_m = 64 * par;
i += 4 * (par - 1);
thread_m_blocks = 4;
}
// For compilation speed, we only define the kernel configurations that have
// seemed useful (in terms of performance) in our testing, however many more
// are, in principle, possible.
if (false) {
}
CALL_IF(8, 8, 256)
CALL_IF(16, 4, 256)
CALL_IF(8, 4, 128)
CALL_IF(4, 8, 128)
else {
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
", " + str(prob_k) + ", " + str(prob_n) + "]" +
", groupsize = " + str(groupsize) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par;
D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
s_tok_ptr += 16 * thread_m_blocks * par;
}
}
} // anonymous namespace
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k) {
// Verify M
TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
", size_m = " + str(size_m));
TORCH_CHECK(size_m == s_tok.numel(),
"Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) +
", size_m = " + str(size_m));
// Verify K
TORCH_CHECK(size_k == a.size(1),
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
TORCH_CHECK(size_k % tile_size == 0,
"size_k = " + str(size_k) +
" is not divisible by tile_size = " + str(tile_size));
TORCH_CHECK(
(size_k / tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) +
", size_k = " + str(size_k) + ", tile_size = " + str(tile_size));
int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0);
// Verify groupsize
TORCH_CHECK(groupsize == -1 || groupsize == 128,
"Unexpected groupsize = " + str(groupsize));
// Verify N
TORCH_CHECK(s_ch.numel() == size_n,
"Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) +
", size_n = " + str(size_n));
TORCH_CHECK(b_q_weight.size(1) % tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(tile_size));
if (groupsize != -1) {
TORCH_CHECK(s_group.size(1) == size_n,
"Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) +
", size_n = " + str(size_n));
TORCH_CHECK(
size_k % s_group.size(0) == 0,
"size_k = " + str(size_k) +
", is not divisible by s_group.size(0) = " + str(s_group.size(0)));
}
int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit;
TORCH_CHECK(size_n == actual_size_n,
"Shape mismatch: size_n = " + str(size_n) +
", actual_size_n = " + str(actual_size_n));
// Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
// Verify B device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
// Verify s_tok device, strides and dtype
TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU");
TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous");
TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32");
// Verify s_ch device, strides and dtype
TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU");
TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous");
TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32");
// Verify s_group device, strides and dtype
TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU");
TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous");
TORCH_CHECK(s_group.dtype() == torch::kFloat16,
"s_group's dtype is not float16");
// Verify workspace size
TORCH_CHECK(size_n % min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " + str(min_thread_n));
int min_workspace_size = (size_n / min_thread_n) * max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) +
" is below min_workspace_size = " + str(min_workspace_size));
// Alloc C matrix
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device());
torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c);
// Alloc D matrix
auto options_d =
torch::TensorOptions().dtype(torch::kFloat16).device(a.device());
torch::Tensor d = torch::empty({size_m, size_n}, options_d);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int sms = -1;
int dev = a.get_device();
marlin_qqq_cuda(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(),
s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n,
size_k, workspace.data_ptr(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
return d;
}
...@@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// marlin_qqq_gemm for QQQ.
ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization. // quantization.
ops.def( ops.def(
......
...@@ -10,6 +10,9 @@ from vllm import _custom_ops as ops ...@@ -10,6 +10,9 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
...@@ -21,6 +24,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( ...@@ -21,6 +24,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_weights) marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize) marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
sort_weights) sort_weights)
...@@ -425,3 +430,64 @@ def test_awq_marlin_gemm( ...@@ -425,3 +430,64 @@ def test_awq_marlin_gemm(
print("max_diff = {}".format(max_diff)) print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
):
int8_traits = torch.iinfo(torch.int8)
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
# Quantize activations
s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
torch.float)
q_a = (a_input / s_a).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
# Quantize weights
w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
marlin_qqq_quantize(b_weight, num_bits, group_size)
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_MAX_PARALLEL)
output = ops.marlin_qqq_gemm(
q_a,
marlin_qqq_q_w,
s_a,
marlin_qqq_s_channel,
marlin_qqq_s_group,
workspace.scratch,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
)
output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
...@@ -389,6 +389,15 @@ def scaled_int8_quant( ...@@ -389,6 +389,15 @@ def scaled_int8_quant(
return output, input_scales return output, input_scales
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
workspace, size_m, size_n, size_k)
# moe # moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config) GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
} }
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]
class QQQConfig(QuantizationConfig):
"""Config class for QQQ
Reference: https://arxiv.org/pdf/2406.09904
"""
def __init__(
self,
weight_bits: int,
group_size: int,
is_sym: bool = True,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.is_sym = is_sym
# Verify
if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
raise ValueError(
f"QQQ does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"QQQ does not support group_size = {self.group_size}. "
f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
raise ValueError(
f"QQQ does not support is_sym = {self.is_sym}. "
f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits
# Tile size used by QQQ kernels.
self.tile_size = MARLIN_QQQ_TILE
# Min out_features dim
self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N
# Min in_features dim
self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = MARLIN_QQQ_MAX_PARALLEL
# Permutation length used by the QQQ kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return "QQQConfig(weight_bits={}, group_size={})".format(
self.weight_bits, self.group_size)
@classmethod
def get_name(cls) -> str:
return "qqq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""
return [
"quant_config.json",
"quantize_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QQQLinearMethod"]:
if isinstance(layer, LinearBase):
return QQQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class QQQLinearMethod(LinearMethodBase):
"""Linear method for QQQ.
Args:
quant_config: The QQQ quantization config.
"""
def __init__(self, quant_config: QQQConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
s_channel = Parameter(
torch.empty(
1,
output_size_per_partition,
device="cuda",
dtype=torch.float,
),
requires_grad=False,
)
set_weight_attrs(
s_channel,
{
"input_dim": None,
"output_dim": 1,
},
)
if self.quant_config.group_size == -1:
s_group = Parameter(
torch.tensor(
[],
device="cuda",
dtype=torch.half,
),
requires_grad=False,
)
else:
s_group = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
device="cuda",
dtype=torch.half,
),
requires_grad=False,
)
set_weight_attrs(
s_group,
{
"input_dim": None if self.quant_config.group_size == -1 else 0,
"output_dim":
None if self.quant_config.group_size == -1 else 1,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s_channel", s_channel)
set_weight_attrs(s_channel, extra_weight_attrs)
layer.register_parameter("s_group", s_group)
set_weight_attrs(s_group, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B
s_ch = layer.s_channel
s_group = layer.s_group
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = s_ch.shape[1]
x_int8, s_tok = ops.scaled_int8_quant(x_2d)
output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
workspace, size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
from typing import List
import numpy
import torch
from .marlin_utils_test import marlin_permute_weights
from .quant_utils import get_pack_factor, qqq_quantize_weights
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
# Pack
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
dtype=numpy.uint32)
if group_size == size_k:
for i in range(pack_factor):
q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
else:
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i
q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
return q_packed
def get_qqq_scale_perms():
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
perm_list: List[int] = []
for i in range(32):
perm1: List[int] = []
col = i // 4
for block in [0, 1]:
for row in [
4 * (i % 4),
4 * (i % 4) + 1,
4 * (i % 4) + 2,
4 * (i % 4) + 3,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm_list)
assert quant_type in ["per-channel",
"per-group"], "not supported quantization type"
if num_bits == 4:
if quant_type == "per-channel":
interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
else:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
else:
raise Exception("num_bits must be 4, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
scale_perm, scale_perm_single = get_qqq_scale_perms()
if group_size < size_k and group_size != -1:
s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
s_channel = s_channel.reshape(
(-1, len(scale_perm_single)))[:, scale_perm_single]
s_group = s_group.reshape((-1, size_n)).contiguous()
else:
s_channel = s_channel.reshape(
(-1, len(scale_perm_single)))[:, scale_perm_single]
s_channel = s_channel.reshape((-1, size_n)).contiguous()
return s_group, s_channel
def marlin_qqq_quantize(
w: torch.Tensor,
num_bits: int,
group_size: int,
):
size_k, size_n = w.shape
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
quant_type = "per-channel" if group_size == size_k else "per-group"
# Quantize
w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
w, num_bits, group_size)
# Reformat to marlin_qqq
weight_perm = get_qqq_weight_perm(num_bits, quant_type)
marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
weight_perm, group_size)
marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
s_group, s_channel, size_k, size_n, group_size)
# Create result
res_list = [
w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
...@@ -205,6 +205,88 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int): ...@@ -205,6 +205,88 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
) )
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
orig_device = w.device
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
assert group_size in SUPPORTED_GROUP_SIZES + [
size_k
], f"Unsupported groupsize = {group_size}"
if group_size == -1:
group_size = size_k
assert group_size <= size_k
if group_size < size_k:
# Reshape to [groupsize, -1]
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
# Compute scale for each group
s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
s_group *= 2 / max_q_val # 2 => symmetric
# Quantize
q_w = torch.round(w / s_group).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s_group
# Restore original shapes
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
# Compute int8 quantization scale for each channel
s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
s_channel /= 127.0
t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
w_ref = t_int8.half() * s_channel
s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
# Fuse scales
s_group = (s_group.reshape(-1, size_n).contiguous() /
s_channel).to(dtype=torch.half)
else:
max_q_val = 2**(num_bits - 1) - 1
# Compute scale for each channel
s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
s_channel /= max_q_val
# Quantize
q_w = torch.round(w / s_channel).int()
q_w = torch.clamp(q_w, -max_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = q_w.half() * s_channel
s_group = torch.tensor([], dtype=torch.half)
# div 2 ** (8 - self.bits)) to offset right shift in unpacking
s_channel /= (2**(8 - num_bits))
s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device),
s_group.to(device=orig_device),
s_channel.to(device=orig_device),
)
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
orig_device = q_w.device orig_device = q_w.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment