Commit a130cf33 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

parents a2d181be 82091b86
......@@ -50,7 +50,10 @@ steps:
command: pytest -v -s worker
- label: LoRA Test
command: pytest -v -s lora
command: pytest -v -s lora --forked
- label: Metrics Test
command: pytest -v -s metrics
- label: Benchmarks
working_dir: "/vllm-workspace/.buildkite"
......
......@@ -25,7 +25,10 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.1.5
pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1
- name: Analysing the code with ruff
run: |
ruff vllm tests
- name: Spelling check with codespell
run: |
codespell --toml pyproject.toml
\ No newline at end of file
......@@ -41,7 +41,7 @@ python3 setup.py install
+ 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
## 验证
- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.3.1
- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.3.3
## Known Issue
-
......
......@@ -73,10 +73,12 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
......
......@@ -7,7 +7,7 @@ On the server side, run one of the following commands:
--disable-log-requests
(TGI backend)
./launch_hf_server.sh <your_model>
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving.py \
......@@ -375,7 +375,7 @@ if __name__ == "__main__":
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disbale tqdm progress bar.",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--save-result",
......
import json
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton
def main():
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs, method=method)
def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100
num_warmup_trials = 1
num_trials = 1
configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": 4,
})
best_config = None
best_time_us = 1e20
for config in configs:
print(f'{tp_size=} {bs=}')
print(f'{config}')
# warmup
print(f'warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
except triton.runtime.autotuner.OutOfResources:
continue
# trial
print(f'benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
)
kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers
if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us
print(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
)
print("best_time_us", best_time_us)
print("best_config", best_config)
filename = "/tmp/config.jsonl"
print(f"writing config to file {filename}")
with open(filename, "a") as f:
f.write(json.dumps({str(bs): best_config}) + "\n")
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float:
shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.bfloat16,
)
ws = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w2s = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=ws,
w2=w2s,
gating_output=gating_output[i],
topk=2,
renormalize=True,
inplace=True,
override_config=config,
)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())
......@@ -2,19 +2,16 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
template<typename T>
__device__ __forceinline__ T silu(const T& x) {
// x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x)));
}
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
// Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
......@@ -22,32 +19,58 @@ __global__ void silu_and_mul_kernel(
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
out[token_idx * d + idx] = ACT_FN(x) * y;
}
}
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x)));
}
template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
const float f = (float) x;
constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
} // namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"act_and_mul_kernel", \
[&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"silu_and_mul_kernel",
[&] {
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
d);
});
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void gelu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
namespace vllm {
......
......@@ -23,13 +23,6 @@ void reshape_and_cache(
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
void gather_cached_kv(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping);
// Just for unittest
void convert_fp8_e5m2(
torch::Tensor& src_cache,
......
......@@ -269,167 +269,6 @@ void reshape_and_cache(
namespace vllm {
// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x) {
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;
const int num_tokens = num_heads * head_size;
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
const int tgt_key_idx = token_idx * key_stride + i;
const int tgt_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
const int x_offset = head_offset % x;
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int src_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
}
}
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int *__restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x)
{
const int token_idx = blockIdx.x;
const int slot_idx = slot_mapping[token_idx];
const int block_idx = slot_idx / block_size;
const int block_offset = slot_idx % block_size;
const int dim = num_heads * head_size;
assert(dim % 4 == 0); // this is true for known use cases
const int unroll_factor = 4;
const int unrolled_dim = dim / unroll_factor;
for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
{
int tgt_key_indices[unroll_factor];
int tgt_value_indices[unroll_factor];
int src_key_indices[unroll_factor];
int src_value_indices[unroll_factor];
scalar_t keys_to_store[unroll_factor];
scalar_t values_to_store[unroll_factor];
#pragma unroll
for (int j = 0; j < unroll_factor; ++j)
{
int index = i + j * unrolled_dim;
const int tgt_key_idx = token_idx * key_stride + index;
const int tgt_value_idx = token_idx * value_stride + index;
const int head_idx = index / head_size;
const int head_offset = index % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int src_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
tgt_key_indices[j] = tgt_key_idx;
tgt_value_indices[j] = tgt_value_idx;
src_key_indices[j] = src_key_idx;
src_value_indices[j] = src_value_idx;
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
}
#pragma unroll
for (int j = 0; j < unroll_factor; ++j)
{
key[tgt_key_indices[j]] = keys_to_store[j];
value[tgt_value_indices[j]] = values_to_store[j];
}
}
}
} // namespace vllm
void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [in] [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
namespace vllm {
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
const Tin* __restrict__ src_cache,
......
......@@ -57,6 +57,10 @@ void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
......@@ -80,6 +84,15 @@ torch::Tensor awq_dequantize(
int split_k_iters,
int thx,
int thy);
torch::Tensor marlin_gemm(
torch::Tensor& a,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
torch::Tensor& workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k);
#endif
void squeezellm_gemm(
......@@ -94,11 +107,13 @@ torch::Tensor gptq_gemm(
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama);
bool use_exllama,
int bit);
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);
torch::Tensor q_perm,
int bit);
void moe_align_block_size(
torch::Tensor topk_ids,
......
......@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
......@@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
......
......@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU.");
ops.def(
"gelu_new",
&gelu_new,
......@@ -48,11 +52,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Quantization ops
// Quantization ops
#ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
......@@ -75,10 +81,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
cache_ops.def(
"convert_fp8_e5m2",
&convert_fp8_e5m2,
......
......@@ -146,6 +146,129 @@ public:
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
class MatrixView_q2_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03;
}
};
class MatrixView_q3_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;
if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
}
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
};
class MatrixView_q8_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff;
}
};
} // namespace gptq
} // namespace vllm
#endif
......@@ -13,7 +13,10 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
#include "compat.cuh"
#include "matrix_view.cuh"
#include "qdq_2.cuh"
#include "qdq_3.cuh"
#include "qdq_4.cuh"
#include "qdq_8.cuh"
namespace vllm {
namespace gptq {
......@@ -22,6 +25,7 @@ namespace gptq {
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_ROWS_8BIT 24
#define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32
#define THREADS_Y 32
......@@ -75,6 +79,106 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
float result = {};
#pragma unroll
for (int i = 0; i < 4; i++)
{
half2 w01 = dq[i];
float w0 = __low2float(w01);
float w1 = __high2float(w01);
float x0 = __half2float(*a_ptr++);
float x1 = __half2float(*a_ptr++);
result = fma(w0, x0, result);
result = fma(w1, x1, result);
}
float qs = __half2float(qs_h);
result *= qs;
half result_h = __float2half_rn(result);
return __hadd(result_h, g_result);
}
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
......@@ -89,8 +193,9 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int*
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
__global__ void gemm_half_q_half_gptq_4bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
......@@ -231,80 +336,794 @@ __global__ void gemm_half_q_half_gptq_kernel
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_2bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
void gemm_half_q_half_cuda_part
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 2);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
b_ptr += size_n;
a_ptr += 16;
}
k += 16;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_3bit_kernel
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
int groups
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
int t = threadIdx.x;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
c,
size_m,
size_n,
size_k,
groups,
b_q_perm
);
}
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / 32 * 3;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
a_ptr += 32;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_8bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 8);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(
bool first_block, const int m_count, const int bit)
{
#define SELECT_KERNEL(M_COUNT) \
if (m_count == M_COUNT) { \
if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>; \
if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>; \
if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>; \
if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>; \
}
#if BLOCK_M_SIZE_MAX >= 1
SELECT_KERNEL(1);
#endif
#if BLOCK_M_SIZE_MAX >= 2
SELECT_KERNEL(2);
#endif
#if BLOCK_M_SIZE_MAX >= 3
SELECT_KERNEL(3);
#endif
#if BLOCK_M_SIZE_MAX >= 4
SELECT_KERNEL(4);
#endif
#if BLOCK_M_SIZE_MAX >= 5
SELECT_KERNEL(5);
#endif
#if BLOCK_M_SIZE_MAX >= 6
SELECT_KERNEL(6);
#endif
#if BLOCK_M_SIZE_MAX >= 7
SELECT_KERNEL(7);
#endif
#if BLOCK_M_SIZE_MAX >= 8
SELECT_KERNEL(8);
#endif
return NULL;
}
void gemm_half_q_half_cuda_part
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
int groups,
int bit
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
c,
size_m,
size_n,
size_k,
groups,
b_q_perm
);
}
__global__ void reconstruct_exllama_8bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 8);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 4; p++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_4bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_3bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / 32* 3;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 1; p++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
if (b_q_perm)
{
for (int j = 0; j < 16; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 16; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_kernel
__global__ void reconstruct_exllama_2bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
......@@ -317,7 +1136,7 @@ __global__ void reconstruct_exllama_kernel
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
......@@ -345,21 +1164,15 @@ __global__ void reconstruct_exllama_kernel
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
int qk = offset_k / (32 / 2);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
......@@ -374,28 +1187,24 @@ __global__ void reconstruct_exllama_kernel
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
for (int p = 0; p < 2; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
for (int j = 0; j < 8; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
......@@ -404,7 +1213,7 @@ __global__ void reconstruct_exllama_kernel
}
else
{
for (int j = 0; j < 4; j++)
for (int j = 0; j < 8; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
......@@ -416,7 +1225,6 @@ __global__ void reconstruct_exllama_kernel
}
}
void reconstruct_exllama
(
const uint32_t* b_q_weight,
......@@ -426,7 +1234,8 @@ void reconstruct_exllama
half* out,
int height,
int width,
int groups
int groups,
int bit
)
{
dim3 blockDim, gridDim;
......@@ -435,6 +1244,15 @@ void reconstruct_exllama
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
if (bit == 2) {
reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
} else if (bit == 3) {
reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
} else if (bit == 8) {
reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
(
......@@ -450,7 +1268,7 @@ void reconstruct_exllama
}
__global__ void gemm_half_q_half_alt_kernel(
__global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec,
const uint32_t* __restrict__ mat,
half* __restrict__ mul,
......@@ -548,6 +1366,95 @@ __global__ void gemm_half_q_half_alt_kernel(
}
__global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec,
const uint32_t* __restrict__ mat,
half* __restrict__ mul,
const half* __restrict__ scales,
const uint32_t* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int height,
int width
)
{
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
int h = BLOCK_KN_SIZE * blockIdx.z / 4;
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
blockvec[m][threadIdx.x] =
vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
threadIdx.x];
}
}
if (blockIdx.z == 0)
{
for (int m = 0; m < b_end; m++)
mul[(b + m) * width + w] = __int2half_rn(0);
}
__syncthreads();
int i = width * h + w;
int g_h = h * 4;
int k = 0;
int z_w = w / 4;
int z_mod = (w % 4) * 8;
half2 res2;
half res[BLOCK_M_SIZE_MAX] = {};
unsigned int tmp;
while (k < h_end) {
tmp = mat[i];
half2 scales_tmp[2];
half2 zeros_tmp[2];
for (int tmp_k = 0; tmp_k < 2; tmp_k++) {
int g = g_idx[g_h + (k + tmp_k) * 2];
int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
half scale_f = scales[g * width + w];
half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2(
__hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
__hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))
);
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
res2 = {};
#else
res2.x = __half_as_ushort(__float2half(0));
res2.y = __half_as_ushort(__float2half(0));
#endif
half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF));
res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF));
res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
#ifndef USE_ROCM
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
}
i += width;
k += 2;
}
for (int m = 0; m < b_end; m++) {
atomicAdd(&mul[(b + m) * width + w], res[m]);
}
}
void gemm_half_q_half_alt
(
const half* a,
......@@ -558,7 +1465,8 @@ void gemm_half_q_half_alt
half* c,
int size_m,
int size_n,
int size_k
int size_k,
int bit
)
{
dim3 blockDim, gridDim;
......@@ -569,8 +1477,13 @@ void gemm_half_q_half_alt
gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
auto kernel = gemm_half_q_half_alt_4bit_kernel;
if (bit == 8) {
kernel = gemm_half_q_half_alt_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half2*) a,
b_q_weight,
......@@ -579,12 +1492,12 @@ void gemm_half_q_half_alt
b_gptq_qzeros,
b_g_idx,
size_m,
size_k / 8,
size_k / 32 * bit,
size_n
);
}
template<class T, int bit>
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ w,
......@@ -600,30 +1513,79 @@ __global__ void reconstruct_gptq_kernel
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 8;
int row = blockIdx.y * 32 / bit;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q4_row w_zeros_(w_zeros, group, width);
T w_zeros_(w_zeros, group, width);
uint32_t w_read = w_.item_uint32_t(row, column);
uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
for (int s = 0; s < 32; s += bit)
{
int group = g_idx[row + s / 4];
int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
__global__ void reconstruct_gptq_3bit_kernel
(
const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int* __restrict__ g_idx,
const int height,
const int width,
const int group,
half* __restrict__ out
)
{
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 32;
if (column >= width) return;
// Views
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width);
uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int i = 0; i < 32; i += 1)
{
int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
int w_item;
if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
} else if (i == 21) {
w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
} else if (i < 10) {
w_item = ((w1 >> (i * 3)) & 0x7);
} else if (i < 21) {
w_item = ((w2 >> (i * 3 - 32)) & 0x7);
} else {
w_item = ((w3 >> (i * 3 - 64)) & 0x7);
}
*out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
out_ptr += out_.width;
}
}
void reconstruct_gptq
(
......@@ -634,16 +1596,28 @@ void reconstruct_gptq
half* out,
int height,
int width,
int groups
int groups,
int bit
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, 8);
gridDim.y = DIVIDE(height, 32 / bit);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
if (bit == 2) {
kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
} else if (bit == 8) {
kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
} else if (bit == 3) {
kernel = reconstruct_gptq_3bit_kernel;
gridDim.y = DIVIDE(height, 32);
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
b_q_weight,
b_gptq_scales,
......@@ -671,19 +1645,27 @@ void gemm_half_q_half_cuda
int size_n,
int size_k,
int groups,
bool use_exllama
bool use_exllama,
int bit
)
{
if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
bool use_reconstruct;
if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
} else {
// The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now.
use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
}
if (use_reconstruct) {
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
size_k, size_n, groups);
size_k, size_n, groups, bit);
}
else
{
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups);
temp_dq, size_k, size_n, groups, bit);
}
const half alpha = __float2half(1.0f);
......@@ -707,7 +1689,7 @@ void gemm_half_q_half_cuda
{
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
groups);
groups, bit);
}
if (last_chunk_size)
......@@ -715,18 +1697,17 @@ void gemm_half_q_half_cuda
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
b_gptq_scales, b_g_idx, c + last_chunk * size_n,
last_chunk_size, size_n, size_k, last_chunk_size,
groups);
groups, bit);
}
}
else
{
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k);
c, size_m, size_n, size_k, bit);
}
}
__global__ void shuffle_kernel
__global__ void shuffle_4bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
......@@ -740,13 +1721,53 @@ __global__ void shuffle_kernel
while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
}
__global__ void shuffle_8bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
}
__global__ void shuffle_2bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
__global__ void shuffle_3bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
}
__global__ void make_sequential_kernel
__global__ void make_sequential_4bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
......@@ -778,37 +1799,204 @@ __global__ void make_sequential_kernel
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
__global__ void make_sequential_2bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 4;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 16; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 4;
int w2_subrow = source_row & 0x0f;
int w2_row_shift = w2_subrow << 1;
int wnew2_row_shift = i << 1;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000300000003;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
__global__ void make_sequential_3bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w_column >= w_width) return;
int w_new_row = blockIdx.y * 3;
int q_perm_idx = blockIdx.y << 5;
uint32_t dst[3] = {0, 0, 0};
#pragma unroll
for (int i = 0; i < 32; i++)
{
int source_row = q_perm[q_perm_idx++];
int z_w = (source_row / 32) * 3;
int z_mod = source_row % 32;
int z_bit;
if (z_mod != 10){
if (z_mod != 21){
z_bit = z_mod;
if (z_bit > 21){
z_bit *= 3;
z_bit -= 64;
z_w += 2;
} else if (z_bit > 10){
z_bit *= 3;
z_bit -= 32;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
uint64_t src;
if (z_mod == 10) {
src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4);
} else if (z_mod == 21){
src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6);
} else {
src = w[z_w * w_width + w_column];
src >>= z_bit;
src &= 0x07;
}
z_w = 0;
if (i != 10){
if (i != 21){
z_bit = i;
if (z_bit > 21){
z_bit *= 3;
z_bit -= 64;
z_w += 2;
} else if (z_bit > 10){
z_bit *= 3;
z_bit -= 32;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
if (i == 10) {
dst[z_w] |= (src & 0x03) << 30;
dst[z_w + 1] |= ((src & 0x4) >> 2);
} else if (i == 21) {
dst[z_w] |= (src & 0x01) << 31;
dst[z_w + 1] |= ((src & 0x6) >> 1);
} else {
dst[z_w] |= (src << z_bit);
}
}
w_new[w_new_row * w_width + w_column] = dst[0];
w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
}
__global__ void make_sequential_8bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 2;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 2;
int w2_subrow = source_row & 0x03;
int w2_row_shift = w2_subrow << 3;
int wnew2_row_shift = i << 3;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x000000ff000000ff;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void shuffle_exllama_weight
(
uint32_t* q_weight,
int* q_perm,
int height,
int width
int width,
int bit
)
{
if (q_perm)
{
uint32_t* new_qweight = NULL;
cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t));
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
gridDim.y = height / 32 * bit;
auto kernel = make_sequential_4bit_kernel;
if (bit == 2) {
kernel = make_sequential_2bit_kernel;
} else if (bit == 3) {
kernel = make_sequential_3bit_kernel;
gridDim.y = height / 32;
} else if (bit == 8) {
kernel = make_sequential_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
q_weight,
new_qweight,
q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(new_qweight);
......@@ -818,6 +2006,14 @@ void shuffle_exllama_weight
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
auto shuffle_kernel = shuffle_4bit_kernel;
if (bit == 2) {
shuffle_kernel = shuffle_2bit_kernel;
} else if (bit == 3) {
shuffle_kernel = shuffle_3bit_kernel;
} else if (bit == 8) {
shuffle_kernel = shuffle_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
}
......@@ -832,13 +2028,14 @@ torch::Tensor gptq_gemm
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama
bool use_exllama,
int bit
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
vllm::gptq::gemm_half_q_half_cuda
(
......@@ -854,7 +2051,8 @@ torch::Tensor gptq_gemm
c.size(1), // n
a.size(1), // k
b_gptq_qzeros.size(0), // group number
use_exllama
use_exllama,
bit
);
return c;
}
......@@ -862,14 +2060,16 @@ torch::Tensor gptq_gemm
void gptq_shuffle
(
torch::Tensor q_weight,
torch::Tensor q_perm
torch::Tensor q_perm,
int bit
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 8,
q_weight.size(1)
q_weight.size(0) * 32 / bit,
q_weight.size(1),
bit
);
}
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
} // namespace gptq
} // namespace vllm
#endif
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
} // namespace gptq
} // namespace vllm
#endif
......@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
......@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
} // namespace gptq
} // namespace vllm
#else
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
} // namespace gptq
} // namespace vllm
#endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
} // namespace gptq
} // namespace vllm
#endif
Contains code from https://github.com/IST-DASLab/marlin
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
template <typename T> inline std::string str(T x) { return std::to_string(x); }
namespace marlin {
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n> struct Vec {
T elems[n];
__device__ T &operator[](int i) { return elems[i]; }
};
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy with a cache hint indicating that the values
// may be evicted immediately; used for quantized weights B, which are only
// accessed precisely once and should thus not pollute the L2 cache which we
// need for inputs A and outputs C.
__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
// Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
FragC &frag_c) {
const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
float *c = reinterpret_cast<float *>(&frag_c);
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
*reinterpret_cast<const half2 *>(&MUL),
*reinterpret_cast<const half2 *>(&ADD));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int *lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with
// a separate quantization scale
>
__global__ void
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
int4 *__restrict__ C, // fp16 output buffer of shape mxn
const int4
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int *locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
if (prob_m > 16 * thread_m_blocks) {
parallel = prob_m / (16 * thread_m_blocks);
prob_m = 16 * thread_m_blocks;
}
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts in
// the middle of group.
if (group_blocks != -1)
iters = (group_blocks / thread_k_blocks) *
ceildiv(iters, (group_blocks / thread_k_blocks));
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count =
0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
}
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
slice_iters = 0;
if (slice_iters == 0)
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0)
slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0)
slice_idx--;
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 8;
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
}
};
init_slice();
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
// We typically use `constexpr` to indicate that this value is a compile-time
// constant
constexpr int a_sh_stride =
16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
constexpr int a_gl_rd_delta_o =
16 * thread_k_blocks /
8; // delta between subsequent A tiles in global memory
int a_gl_rd_delta_i =
a_gl_stride *
(threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
constexpr int a_sh_wr_delta =
a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
constexpr int a_sh_rd_delta_o =
2 * ((threads / 32) /
(thread_n_blocks / 4)); // between shared memory tile reads
constexpr int a_sh_rd_delta_i =
a_sh_stride * 16; // within a shared memory tile
constexpr int a_sh_stage =
a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
constexpr int a_sh_wr_iters =
ceildiv(a_sh_stage,
a_sh_wr_delta); // number of shared write iterations for a tile
int b_gl_stride = 16 * prob_n / 32;
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
constexpr int b_sh_wr_delta = threads;
constexpr int b_sh_rd_delta = threads;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_sh_stage = s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd =
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x;
int b_sh_rd = threadIdx.x;
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
int s_sh_wr = threadIdx.x;
int s_sh_rd;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
if (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const int4 *B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4 *sh_a = sh;
int4 *sh_b = sh_a + (stages * a_sh_stage);
int4 *sh_s = sh_b + (stages * b_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4];
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float *>(frag_c)[i] = 0;
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]);
}
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<stages - 2>();
__syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
// It may seem inefficient that we reload the groups for every sub-tile;
// however, this does not seem to be a significant bottleneck, while some
// theoretically better attempts have lead to bad instruction ordering by
// the compiler and correspondingly a noticeable drop in performance.
if (group_blocks != -1) {
int4 *sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
}
int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
int b_quant = frag_b_quant[k % 2][j];
int b_quant_shift = b_quant >> 8;
FragB frag_b0 = dequant(b_quant);
// If there are no groups, we can just scale the final output once and can
// avoid doing so for each weight.
if (group_blocks != -1)
scale(frag_b0, frag_s[k % 2][j], 0);
FragB frag_b1 = dequant(b_quant_shift);
if (group_blocks != -1)
scale(frag_b1, frag_s[k % 2][j], 1);
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride;
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
constexpr int red_sh_delta = b_sh_stride;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
(threadIdx.x % b_sh_stride);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float *c_rd = reinterpret_cast<float *>(
&sh[red_sh_delta * j + red_sh_rd]);
float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float *c_rd =
reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
int row = (threadIdx.x % 32) / 4;
if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the
// compiler and lead to slowdowns, hence we also use async-copies even though
// these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 ||
8 * (i / 2) + row < prob_m);
}
cp_async_fence();
cp_async_wait<0>();
}
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float *>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
__half2float(reinterpret_cast<__half *>(&c_red)[j]);
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<__half *>(&c)[j] =
__float2half(reinterpret_cast<float *>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
c;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int c_sh_rd_delta =
c_sh_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS &s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
if (group_blocks ==
-1) // for per-column quantization we finally apply the scale here
res = __hmul2(res, s[0]);
((half2 *)sh)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
}
c_sh_wr += 16 * (4 * c_sh_stride);
}
}
__syncthreads();
#pragma unroll
for (int i = 0;
i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++)
fetch_to_shared(i, i, i < slice_iters);
zero_accums();
wait_for_stage();
fetch_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
};
start_pipes();
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure
// all shared memory accesses are static. Note that both pipelines have even
// length meaning that the next iteration will always start at index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0)
break;
}
a_gl_rd += a_gl_rd_delta_o * stages;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if (group_blocks == -1 && last) {
if (s_sh_wr_pred)
cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
thread_block_reduce();
if (group_blocks == -1 && last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] -= b_gl_stride;
}
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes();
}
}
}
}
#else
template <const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with
// a separate quantization scale
>
__global__ void
Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
int4 *__restrict__ C, // fp16 output buffer of shape mxn
const int4
*__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int *locks // extra global storage for barrier synchronization
) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
#endif
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM =
96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int pack_factor_4bit =
8; // We have 8 4-bit vals inside a 32 bit
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
GROUP_BLOCKS, NUM_THREADS) \
else if (thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
SHARED_MEM); \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \
}
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
};
bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
int prob_k) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
return true;
}
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
return th_config;
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
return th_config;
}
}
}
return thread_config_t{-1, -1, -1};
}
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)
void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m,
int prob_n, int prob_k, void *workspace, int groupsize = -1,
int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
int thread_n = -1, int sms = -1, int max_par = 16) {
int tot_m = prob_m;
int tot_m_blocks = ceildiv(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
if (sms == -1)
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
// Set thread config
thread_config_t th_config;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
} else {
// Auto config
th_config = determine_thread_config(prob_m, prob_n, prob_k);
}
if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
throw std::runtime_error(
"Invalid thread config: thread_k = " + str(th_config.thread_k) +
", thread_n = " + str(th_config.thread_n) +
", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
}
// Uncomment for debug
// std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
// ", thread_n = " + str(th_config.thread_n) +
// ", num_threads = " + str(th_config.num_threads) + " for
// MKN = [" + str(prob_m) +
// ", " + str(prob_k) + ", " + str(prob_n) + "]\n";
int num_threads = th_config.num_threads;
thread_k = th_config.thread_k;
thread_n = th_config.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
int blocks = sms;
if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
return;
}
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
if (group_blocks != -1) {
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
const int4 *A_ptr = (const int4 *)A;
const int4 *B_ptr = (const int4 *)B;
int4 *C_ptr = (int4 *)C;
const int4 *s_ptr = (const int4 *)s;
int *locks = (int *)workspace;
for (int i = 0; i < tot_m_blocks; i += 4) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
if (thread_m_blocks > 4) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_m_blocks - pad) / 64;
if (par > max_par)
par = max_par;
prob_m = 64 * par;
i += 4 * (par - 1);
thread_m_blocks = 4;
}
// For compilation speed, we only define the kernel configurations that have
// seemed useful (in terms of performance) in our testing, however many more
// are, in principle, possible.
if (false) {
}
CALL_IF(8, 8, 256)
CALL_IF(16, 4, 256)
CALL_IF(8, 4, 128)
CALL_IF(4, 8, 128)
else {
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
", " + str(prob_k) + ", " + str(prob_n) + "]" +
", groupsize = " + str(groupsize) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}
}
} // namespace marlin
torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k) {
// Verify M
TORCH_CHECK(size_m == a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
", size_m = " + str(size_m));
// Verify K
TORCH_CHECK(size_k == a.size(1),
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
TORCH_CHECK(size_k % marlin::tile_size == 0,
"size_k = " + str(size_k) +
" is not divisible by tile_size = " + str(marlin::tile_size));
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " +
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
", tile_size = " + str(marlin::tile_size));
// Verify N
TORCH_CHECK(b_scales.size(1) == size_n,
"b_scales.size(1) = " + str(b_scales.size(1)) +
", size_n = " + str(size_n));
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(marlin::tile_size));
int actual_size_n =
(b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
TORCH_CHECK(size_n == actual_size_n,
"size_n = " + str(size_n) +
", actual_size_n = " + str(actual_size_n));
// Verify A device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
// Verify B device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
// Verify scales device and strides
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// Alloc C matrix
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, size_n}, options);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int sms = -1;
// Detect groupsize
if (b_scales.size(0) != 1) {
TORCH_CHECK(size_k % b_scales.size(0) == 0,
"size_k = " + str(size_k) +
", is not divisible by b_scales.size(0) = " +
str(b_scales.size(0)));
}
int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);
// Verify groupsize
TORCH_CHECK(groupsize == -1 || groupsize == 128,
"Unexpected groupsize = " + str(groupsize));
// Verify workspace size
TORCH_CHECK(
size_n % marlin::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) +
" is below min_workspace_size = " + str(min_workspace_size));
int dev = a.get_device();
marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
b_scales.data_ptr(), size_m, size_n, size_k,
workspace.data_ptr(), groupsize, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
sms, marlin::max_par);
return c;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment