Commit c3290cad authored by Li Zhang, committed by GitHub

[Feature] Blazing fast W4A16 inference (#202)

* add w4a16

* fix `deploy.py`

* add doc

* add w4a16 kernels

* fuse w1/w3 & bugfixes

* fix typo

* python

* guard sm75/80 features

* add missing header

* refactor

* qkvo bias

* update cost model

* fix lint

* update `deploy.py`
parent d3dbe179
@@ -304,6 +304,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:llama_fmha>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
......
@@ -34,6 +34,19 @@ bash workspace/service_docker_up.sh
</details>
<details open>
<summary><b>7B with INT4 weight only quantization</b></summary>
```shell
python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-7b-chat-hf \
--model_format awq \
--group_size 128 \
--quant_path /path/to/awq-quant-weight.pt
bash workspace/service_docker_up.sh
```
</details>
## Serving [LLaMA](https://github.com/facebookresearch/llama)
Weights for the LLaMA models can be obtained by filling out [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
......
@@ -5,6 +5,7 @@ import os
import os.path as osp
import re
import shutil
import sys
from pathlib import Path
import fire
@@ -12,9 +13,10 @@ import safetensors
import torch
from sentencepiece import SentencePieceProcessor
import lmdeploy
from lmdeploy.model import MODELS
supported_formats = ['llama', 'hf']
supported_formats = ['llama', 'hf', 'awq']
def get_package_root_path():
@@ -107,7 +109,9 @@ def export(model_name: str,
tokenizer_path: str,
out_dir: str,
tp: int,
size_per_head: int = 128):
size_per_head: int = 128,
group_size: int = 0,
weight_type: str = 'fp16'):
"""Export deploying information to a config file.
Args:
@@ -127,9 +131,10 @@ def export(model_name: str,
print(name, param.shape)
if param.dtype in [torch.float, torch.bfloat16]:
param = param.half()
param.contiguous().numpy().tofile(osp.join(out_dir, name))
param.contiguous().cpu().numpy().tofile(osp.join(out_dir, name))
attn_bias = False
inter_size = 0
# reverse the splitting axes since the weights are transposed above
for param_name, param_data in model_params.items():
@@ -141,10 +146,14 @@ def export(model_name: str,
if key == 'w_qkv' and ext == 'bias':
attn_bias = True
copy = False
if key in ['w1', 'w3']:
if key in ['w1', 'w3', 'w13']:
split_dim = -1
# TODO: move parameter extraction outside of the loop
if key == 'w1':
inter_size = param_data.shape[-1]
inter_size = max(inter_size, param_data.shape[-1])
elif key == 'w13':
inter_size = max(inter_size, param_data.shape[-1] // 2)
elif key == 'w_qkv':
split_dim = -2
elif key in ['w2', 'wo']:
@@ -170,6 +179,8 @@ def export(model_name: str,
else:
save_bin(param_data, param_name)
assert inter_size > 0
# export config and save it to {out_dir}/config.ini
model = MODELS.get(model_name)()
vocab_size, bos_id, eos_id = tokenizer_info(tokenizer_path)
@@ -188,7 +199,8 @@ def export(model_name: str,
attn_bias=int(attn_bias),
start_id=bos_id,
end_id=eos_id,
weight_type='fp16',
weight_type=weight_type,
group_size=group_size,
# parameters for turbomind
max_batch_size=32,
max_context_token_num=4,
@@ -329,7 +341,7 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
def permute(x: torch.Tensor):
SIZE_PER_HEAD = 128
if x.shape[-1] > 1: # qweights
if x.shape[-1] > 1:
dim = x.shape[-1]
n_heads = dim // SIZE_PER_HEAD
return x.view(-1, n_heads, 2,
@@ -491,6 +503,228 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
tokenizer_path, triton_models_path, tp)
def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int, quant_path: str,
group_size: int):
"""Deploy a model with huggingface transformers' format.
Args:
model_name (str): the name of the to-be-deployed model
model_path (str): the path of the directory where the model weight
files are
tokenizer_path (str): the path of the tokenizer model
triton_models_path (str): the path of the exported triton models
tp (int): the number of GPUs used for tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
"""
if tokenizer_path is None:
tokenizer_path = osp.join(model_path, 'tokenizer.model')
if osp.exists(tokenizer_path):
shutil.copy(tokenizer_path,
osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
for _file in os.listdir(model_path):
if _file.endswith('.json') or _file.endswith('.py'):
json_path = osp.join(model_path, _file)
shutil.copy(json_path,
osp.join(triton_models_path, 'tokenizer', _file))
with get_package_root_path() as root_path:
shutil.copy(osp.join(root_path, 'turbomind/tokenizer.py'),
osp.join(triton_models_path, 'tokenizer'))
else:
print(f'tokenizer model {tokenizer_path} does not exist')
exit(-1)
# read model arguments from config.json
try:
params_path = osp.join(model_path, 'config.json')
with open(params_path) as f:
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
except Exception as e:
print(f'get "num_hidden_layers" and "rms_norm_eps" from '
f'{params_path} failed: {e}')
return False
# convert weights from hf to turbomind
if quant_path is None:
_files = [
osp.join(model_path, file) for file in os.listdir(model_path)
if file.endswith('.bin')
]
_files = sorted(_files)
else:
_files = [quant_path]
model_params = {}
_params = {}
for _file in _files:
_tmp = torch.load(_file, map_location='cpu')
_params.update(_tmp)
def get_tensor(name):
"""return tensor according its name."""
return _params[name].cuda().contiguous()
# import _turbomind as _tm
# TODO: find another way to import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm # noqa: E402
def transpose_qk(src: torch.Tensor):
assert src.is_contiguous()
dst = torch.zeros_like(src)
_tm.transpose_qk_s4_k_m8(src, dst,
src.size(-1) * 8, src.size(0), group_size)
return dst
def fuse_w1_w3(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
w1_s: torch.Tensor, w3_qw: torch.Tensor,
w3_qz: torch.Tensor, w3_s: torch.Tensor):
def fuse(a: torch.Tensor, b: torch.Tensor):
ab = torch.cat((a, b)).contiguous()
_ab = torch.zeros_like(ab)
_tm.fuse_w1_w3_s4_k_m8(ab, _ab, a.size(-1) * 8, a.size(0))
return _ab.view(a.size(0), -1)
w13_qw = fuse(w1_qw, w3_qw)
w13_qz = fuse(w1_qz, w3_qz)
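# interleave the w1/w3 scales per output channel so they match the fused qweight layout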
w13_s = torch.cat((w1_s, w3_s)).view(2, w1_s.size(0), -1)
w13_s = w13_s.permute(1, 2, 0).contiguous().view(w1_s.size(0), -1)
return w13_qw, w13_qz, w13_s
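# repack the AWQ qweight into turbomind's layout and fold scales & zeros into a
# single `scales_zeros` tensor (see `convert_s4_k_m8` in format.cu)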
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
assert qw.is_contiguous()
assert qz.is_contiguous()
assert s.is_contiguous()
_qw = torch.zeros_like(qw)
_sz = torch.zeros_like(s, dtype=torch.int32)
_ws = torch.zeros_like(s)
_tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz
attn_bias = False
for i in range(num_layer):
print(i)
# attention weights
q_qw = get_tensor(f'model.layers.{i}.self_attn.q_proj.qweight')
k_qw = get_tensor(f'model.layers.{i}.self_attn.k_proj.qweight')
v_qw = get_tensor(f'model.layers.{i}.self_attn.v_proj.qweight')
o_qw = get_tensor(f'model.layers.{i}.self_attn.o_proj.qweight')
q_qz = get_tensor(f'model.layers.{i}.self_attn.q_proj.qzeros')
k_qz = get_tensor(f'model.layers.{i}.self_attn.k_proj.qzeros')
v_qz = get_tensor(f'model.layers.{i}.self_attn.v_proj.qzeros')
o_qz = get_tensor(f'model.layers.{i}.self_attn.o_proj.qzeros')
q_s = get_tensor(f'model.layers.{i}.self_attn.q_proj.scales')
k_s = get_tensor(f'model.layers.{i}.self_attn.k_proj.scales')
v_s = get_tensor(f'model.layers.{i}.self_attn.v_proj.scales')
o_s = get_tensor(f'model.layers.{i}.self_attn.o_proj.scales')
try:
q_b = get_tensor(f'model.layers.{i}.self_attn.q_proj.bias')
k_b = get_tensor(f'model.layers.{i}.self_attn.k_proj.bias')
v_b = get_tensor(f'model.layers.{i}.self_attn.v_proj.bias')
o_b = get_tensor(f'model.layers.{i}.self_attn.o_proj.bias')
attn_bias = True
except: # noqa: E722
pass
q_qw = transpose_qk(q_qw)
k_qw = transpose_qk(k_qw)
q_qz = transpose_qk(q_qz)
k_qz = transpose_qk(k_qz)
q_s = permute(q_s)
k_s = permute(k_s)
qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2)
qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
model_params[f'layers.{i}.attention.w_qkv.qweight'] = qkv_qw
model_params[f'layers.{i}.attention.w_qkv.scales_zeros'] = qkv_sz
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
model_params[f'layers.{i}.attention.wo.qweight'] = o_qw
model_params[f'layers.{i}.attention.wo.scales_zeros'] = o_sz
if attn_bias:
q_b = permute(q_b)
k_b = permute(k_b)
qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b
model_params[f'layers.{i}.attention.wo.bias'] = o_b
# ffn weights
w1_qw = get_tensor(f'model.layers.{i}.mlp.gate_proj.qweight')
w2_qw = get_tensor(f'model.layers.{i}.mlp.down_proj.qweight')
w3_qw = get_tensor(f'model.layers.{i}.mlp.up_proj.qweight')
w1_qz = get_tensor(f'model.layers.{i}.mlp.gate_proj.qzeros')
w2_qz = get_tensor(f'model.layers.{i}.mlp.down_proj.qzeros')
w3_qz = get_tensor(f'model.layers.{i}.mlp.up_proj.qzeros')
w1_s = get_tensor(f'model.layers.{i}.mlp.gate_proj.scales')
w2_s = get_tensor(f'model.layers.{i}.mlp.down_proj.scales')
w3_s = get_tensor(f'model.layers.{i}.mlp.up_proj.scales')
w13_qw, w13_qz, w13_s = fuse_w1_w3(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
w3_s)
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
model_params[f'layers.{i}.feed_forward.w13.qweight'] = w13_qw
model_params[f'layers.{i}.feed_forward.w13.scales_zeros'] = w13_sz
model_params[f'layers.{i}.feed_forward.w2.qweight'] = w2_qw
model_params[f'layers.{i}.feed_forward.w2.scales_zeros'] = w2_sz
# norm weights
attn_norm = get_tensor(f'model.layers.{i}.input_layernorm.weight')
ffn_norm = get_tensor(
f'model.layers.{i}.post_attention_layernorm.weight')
model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm
other = [('tok_embeddings.weight', 'model.embed_tokens.weight'),
('norm.weight', 'model.norm.weight'),
('output.weight', 'lm_head.weight')]
for ft, hf in other:
model_params[ft] = get_tensor(hf)
return export(model_name,
num_layer,
norm_eps,
kv_head_num,
model_params,
tokenizer_path,
triton_models_path,
tp,
weight_type='int4',
group_size=group_size)
def pack_model_repository(workspace_path: str):
"""package the model repository.
@@ -521,7 +755,9 @@ def main(model_name: str,
model_format: str = 'hf',
tokenizer_path: str = None,
dst_path: str = './workspace',
tp: int = 1):
tp: int = 1,
quant_path: str = None,
group_size: int = 0):
"""deploy llama family models via turbomind.
Args:
@@ -533,6 +769,9 @@ def main(model_name: str,
tokenizer_path (str): the path of tokenizer model
dst_path (str): the destination path that saves outputs
tp (int): the number of GPUs used for tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
"""
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
@@ -558,9 +797,12 @@ def main(model_name: str,
if model_format == 'llama':
res = deploy_llama(model_name, model_path, tokenizer_path,
triton_models_path, tp)
else:
elif model_format == 'hf':
res = deploy_hf(model_name, model_path, tokenizer_path,
triton_models_path, tp)
elif model_format == 'awq':
res = deploy_awq(model_name, model_path, tokenizer_path,
triton_models_path, tp, quant_path, group_size)
# update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
......
@@ -69,3 +69,5 @@ set_property(TARGET sampling_penalty_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOL
add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_subdirectory(gemm_s_f16)
# Copyright (c) OpenMMLab. All rights reserved.
add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu)
target_compile_options(gemm_s4_f16 PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
set_property(TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gemm_s4_f16 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cassert>
#include <cstdint>
#include <cuda_fp16.h>
#include <type_traits>
namespace turbomind {
#ifndef TURBOMIND_S4_DEQUANT_USE_FMA
#define TURBOMIND_S4_DEQUANT_USE_FMA 0
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#define TURBOMIND_ARCH_SM75 1
#else
#define TURBOMIND_ARCH_SM75 0
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define TURBOMIND_ARCH_SM80 1
#else
#define TURBOMIND_ARCH_SM80 0
#endif
constexpr int WARP_SIZE = 32;
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
#define PRAGMA_UNROLL _Pragma("unroll")
#define PRAGMA_NO_UNROLL _Pragma("unroll 1")
#else
#define PRAGMA_UNROLL #pragma unroll
#define PRAGMA_NO_UNROLL #pragma unroll 1
#endif
#else
#define PRAGMA_UNROLL
#define PRAGMA_NO_UNROLL
#endif
// Modified from NVIDIA FasterTransformer:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// Modified from llm-awq https://github.com/mit-han-lab/llm-awq/blob/main/awq/kernels/csrc/quantization/dequantize.cuh
__inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}
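// Illustrative host-side sketch (not used by the kernels here): why the 0x6400 trick works.
// OR-ing a 4-bit value v into the low mantissa bits of the fp16 pattern 0x6400 (= 1024.0)
// yields the fp16 value 1024 + v, so subtracting FP16_TOP_MAGIC_NUM recovers v exactly.
// The helper name is an assumption for this illustration; it relies only on cuda_fp16.h above.
inline bool check_s4_magic_number_trick()
{
    for (uint16_t v = 0; v < 16; ++v) {
        __half_raw raw;
        raw.x         = static_cast<uint16_t>(0x6400 | v);  // splice v into the mantissa
        const float f = __half2float(__half(raw));
        if (f != 1024.0f + float(v)) {
            return false;
        }
    }
    // Likewise 0x2c00 decodes to 1/16 (ONE_SIXTEENTH) and 0xd400 to -64 (NEG_64).
    return true;
}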
__inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const& i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOT_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t MAGIC_NUM_0 = 0x64006400; // `1024`
static constexpr uint32_t MAGIC_NUM_1 = 0x54005400; // `64`
static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4; // `64` >> 4
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
if (0) { // 1024 & 64
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
}
else { // 64 only, trade 4 hfma2 with 2 shifts
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
h[0] <<= 4;
h[2] <<= 4;
// we don't need to subtract the magic nums because zeros will go through the same dequant function
// and carry the same magic constant, the magic num will be canceled out after subtracting zeros
}
return result;
}
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
}
__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(smem_int_ptr));
#else
assert(TURBOMIND_ARCH_SM75);
#endif
}
__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(d0), "=r"(d1) : "r"(smem_int_ptr));
#else
assert(TURBOMIND_ARCH_SM75);
#endif
}
__inline__ __device__ void wait_flag(int* lock, int status, int thread_id)
{
int state = 0;
while (__syncthreads_and(state != status)) {
if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
__syncthreads(); // memory fence
}
__inline__ __device__ void release_flag(int* lock, int status, int thread_id)
{
__syncthreads(); // memory fence
if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
uint s, z;
(half2&)z = __halves2half2(q.x, q.x);
(half2&)s = __halves2half2(q.y, q.y);
auto& t = (const uint&)x;
uint u, v;
if (TURBOMIND_S4_DEQUANT_USE_FMA) {
asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
}
else {
asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
}
return (half2&)v;
}
template<typename T, int N>
struct Array {
T a[N];
__device__ __host__ constexpr T& operator[](int i) noexcept
{
return a[i];
}
__device__ __host__ constexpr const T& operator[](int i) const noexcept
{
return a[i];
}
};
template<int... Ns>
struct Shape {
static constexpr Array<int, sizeof...(Ns)> data_{Ns...};
constexpr Shape() = default;
Shape(std::integral_constant<int, Ns>...){};
template<int index>
constexpr auto get() const noexcept
{
return std::integral_constant<int, data_[index]>{};
}
constexpr auto m() const noexcept
{
return get<0>();
}
constexpr auto n() const noexcept
{
return get<1>();
}
constexpr auto k() const noexcept
{
return get<2>();
}
constexpr int c() const noexcept
{
return get<0>();
}
constexpr int s() const noexcept
{
return get<1>();
}
constexpr int count() const noexcept
{
return (Ns * ...);
}
};
template<int... Ns>
Shape(std::integral_constant<int, Ns>...) -> Shape<Ns...>;
template<int... Ns>
inline constexpr Shape<Ns...> shape_c{};
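// Usage sketch: the shape tuples are consumed as compile-time constants, e.g.
//   static_assert(shape_c<256, 128, 32>.m() == 256 && shape_c<256, 128, 32>.k() == 32, "");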
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "common.h"
#include <cstdint>
namespace turbomind {
template<typename T>
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
{
#if TURBOMIND_ARCH_SM80
constexpr int cp_size = sizeof(T);
static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global.L2::256B [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
template<typename T>
__inline__ __device__ void cp_async_cg_B(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
{
#if TURBOMIND_ARCH_SM80
constexpr int cp_size = sizeof(T);
static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
template<typename T>
__inline__ __device__ void cp_async_ca(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
{
#if TURBOMIND_ARCH_SM80
constexpr int cp_size = sizeof(T);
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
struct IteratorA {
static constexpr int SLICE_K = CTA_K / SLICES;
using AccessType = uint4;
static constexpr int kAccessSize = sizeof(AccessType);
static_assert(CTA_M % 32 == 0 && CTA_K % 32 == 0, "A is pre-formatted as 32x32 tiles");
// A is [K/32, M/32, WARP_SIZE] uint4
static constexpr int kShapeM = CTA_M;
static constexpr int kShapeK = SLICE_K / 32;
// thread access shape
static constexpr int kAccessM = 1;
static constexpr int kAccessK = 1;
// warp thread arrangement
static constexpr int kWarpThreadC = 32;
static constexpr int kWarpThreadS = 1;
// warp shape per access
static constexpr int kWarpAccessM = kWarpThreadC * kAccessM; // 32
static constexpr int kWarpAccessK = kWarpThreadS * kAccessK; // 1
// warp access iterations
static constexpr int kWarpIterM = kShapeM / kWarpAccessM;
static constexpr int kWarpIterK = kShapeK / kWarpAccessK;
// warp arrangement
static constexpr int kWarpM = kWarpIterM >= WARPS ? WARPS : kWarpIterM;
static constexpr int kWarpK = WARPS > kWarpIterM ? (WARPS / kWarpM) : 1;
// iterations
static constexpr int kIterM = kWarpIterM / kWarpM;
static constexpr int kIterK = kWarpIterK / kWarpK;
static constexpr int kIterCount = kIterM * kIterK;
static_assert(kIterCount > 0);
// warp footprint
static constexpr int kWarpFootprintM = kWarpAccessM * kIterM;
static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
static constexpr int kSizePerStage = kShapeK * kShapeM;
static constexpr int kSmemByteSize = kAccessSize * STAGES * kSizePerStage;
const uint* src_;
AccessType* smem_;
uint32_t smem_int_ptr_;
const int m_;
const int k_;
const int warp_id_;
const int lane_id_;
int src_offset_;
int dst_offset_;
int src_step_m_;
int src_step_k_;
int src_step_s_;
int dst_step_m_;
int dst_step_k_;
int dst_step_s_;
int iter_m_{0};
IteratorA() = default;
__device__ IteratorA(const uint* src, void* smem, int m, int k, int cta_m, int cta_k, int warp_id, int lane_id):
src_(src),
smem_((AccessType*)smem),
smem_int_ptr_(cast_smem_ptr_to_uint(smem)),
m_(m),
k_(k),
warp_id_(warp_id),
lane_id_(lane_id)
{
const int warp_offset_m = warp_id_ % kWarpM;
const int warp_offset_k = warp_id_ / kWarpM;
const int warp_thread_offset_m = lane_id_ % kWarpThreadC;
const int warp_thread_offset_k = lane_id_ / kWarpThreadC;
const int cta_thread_offset_m = kWarpFootprintM * warp_offset_m + warp_thread_offset_m * kAccessM;
const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
const int src_offset_m = cta_thread_offset_m + cta_m;
const int src_offset_k = cta_thread_offset_k + cta_k / 32;
src_offset_ = src_offset_k * m_ + src_offset_m;
src_step_m_ = kWarpAccessM;
src_step_k_ = kWarpAccessK * m_ - kIterM * kWarpAccessM;
src_step_s_ = CTA_K / 32 * m_ - kIterK * kWarpAccessK * m_;
const int dst_offset_m = cta_thread_offset_m;
const int dst_offset_k = cta_thread_offset_k;
dst_offset_ = dst_offset_k * kShapeM + dst_offset_m;
dst_step_m_ = kWarpAccessM;
dst_step_k_ = kWarpAccessK * kShapeM - kIterM * kWarpAccessM;
dst_step_s_ = SLICE_K / 32 * kShapeM - kIterK * kWarpAccessK * kShapeM;
dst_offset_ *= kAccessSize;
dst_step_m_ *= kAccessSize;
dst_step_k_ *= kAccessSize;
dst_step_s_ *= kAccessSize;
}
__device__ void prefetch_stage(bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
prefetch(mask);
++(*this);
}
next_stage();
}
__device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
prefetch(mask);
++(*this);
}
}
}
__device__ IteratorA& operator++()
{
src_offset_ += src_step_m_;
dst_offset_ += dst_step_m_;
++iter_m_;
if (iter_m_ < kIterM) {
return *this;
}
iter_m_ = 0;
src_offset_ += src_step_k_;
dst_offset_ += dst_step_k_;
return *this;
}
__device__ void next_stage()
{
src_offset_ += src_step_s_;
dst_offset_ += dst_step_s_;
if (dst_offset_ >= kSmemByteSize) {
dst_offset_ -= kSmemByteSize;
}
}
__device__ void prefetch(bool mask)
{
cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
}
};
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES, int GROUP_SIZE>
struct IteratorQ {
static constexpr int SLICE_K = CTA_K / SLICES;
using AccessType = uint;
static constexpr int kAccessSize = sizeof(AccessType);
static constexpr int kAccessM = kAccessSize / sizeof(half2);
static constexpr int kAccessK = GROUP_SIZE;
// warp thread arrangement
static constexpr int kWarpThreadC = 32;
static constexpr int kWarpThreadS = 1;
// warp shape per access
static constexpr int kWarpAccessM = kWarpThreadC * kAccessM; // 32
static constexpr int kWarpAccessK = kWarpThreadS * kAccessK; // GROUP_SIZE
// warp access iterations
static constexpr int kWarpIterM = CTA_M / kWarpAccessM; // CTA_M / 32
static constexpr int kWarpIterK = SLICE_K / kWarpAccessK; // SLICE_K / GROUP_SIZE, maybe 0
// kWarpIterK == 0 => SLICE_K < kWarpAccessK => kIterK == 1
// warp arrangement
static constexpr int kWarpM = kWarpIterM >= WARPS ? WARPS : kWarpIterM;
static constexpr int kWarpK = WARPS > kWarpIterM ? WARPS / kWarpM : 1;
// iterations
static constexpr int kIterM = kWarpIterM / kWarpM;
static constexpr int kIterK = kWarpIterK >= kWarpK ? kWarpIterK / kWarpK : 1;
static constexpr int kIterCount = kIterM * kIterK;
// warp footprint
static constexpr int kWarpFootprintM = kWarpAccessM * kIterM;
static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
static constexpr int kSizePerStage = std::max(SLICE_K / GROUP_SIZE, 1) * CTA_M;
static constexpr int kSmemByteSize = sizeof(uint) * STAGES * kSizePerStage;
const half2* const src_;
half2* const smem_;
uint32_t const smem_int_ptr_;
const int m_;
const int k_;
bool is_out_of_bound_; // mask for out-of-bound warps
int src_offset_k_;
int src_offset_m_;
int src_offset_;
int src_step_m_;
int src_step_k_;
int dst_offset_;
int dst_step_m_;
int dst_step_k_;
int tmp_src_offset_;
int tmp_dst_offset_;
int iter_m_{0};
struct Storage {
half2 data[SLICES][STAGES * kSizePerStage];
};
IteratorQ() = default;
__device__ IteratorQ(const half2* src, half2* smem, int m, int k, int cta_m, int cta_k, int warp_id, int lane_id):
src_(src), smem_(smem), smem_int_ptr_(cast_smem_ptr_to_uint(smem)), m_(m), k_(k)
{
const int warp_offset_m = warp_id % kWarpM;
const int warp_offset_k = warp_id / kWarpM;
const int warp_thread_offset_m = lane_id % kWarpThreadC;
const int warp_thread_offset_k = lane_id / kWarpThreadC;
const int cta_thread_offset_m = kWarpFootprintM * warp_offset_m + warp_thread_offset_m * kAccessM;
const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
// mask out-of-bound warps
is_out_of_bound_ = cta_thread_offset_k >= SLICE_K;
src_offset_m_ = cta_thread_offset_m + cta_m;
src_offset_k_ = cta_thread_offset_k + cta_k;
src_offset_ = src_offset_k_ / GROUP_SIZE * m_ + src_offset_m_;
src_step_m_ = kWarpAccessM;
src_step_k_ = m_ - kIterM * kWarpAccessM; // valid only when SLICE_K >= GROUP_SIZE
const int dst_offset_m = cta_thread_offset_m;
const int dst_offset_k = cta_thread_offset_k;
dst_offset_ = dst_offset_k / GROUP_SIZE * CTA_M + dst_offset_m;
dst_step_m_ = kWarpAccessM;
dst_step_k_ = CTA_M - kIterM * kWarpAccessM; // valid only when SLICE_K >= GROUP_SIZE
dst_offset_ *= kAccessSize;
dst_step_m_ *= kAccessSize;
dst_step_k_ *= kAccessSize;
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
}
__device__ void prefetch_stage(bool mask)
{
if (is_out_of_bound_) {
return;
}
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
prefetch(mask);
++(*this);
}
next_stage();
}
__device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
{
if (is_out_of_bound_) {
return;
}
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
prefetch(mask);
++(*this);
}
}
}
__device__ IteratorQ& operator++()
{
++iter_m_;
src_offset_ += src_step_m_;
dst_offset_ += dst_step_m_;
if (iter_m_ < kIterM) {
return *this;
}
iter_m_ = 0;
if constexpr (SLICE_K >= GROUP_SIZE) {
src_offset_ += src_step_k_;
dst_offset_ += dst_step_k_;
}
// else advance offsets in `next_stage`
return *this;
}
__device__ void next_stage()
{
if constexpr (SLICE_K >= GROUP_SIZE) {
src_offset_ += (CTA_K / GROUP_SIZE - kIterK) * m_;
dst_offset_ += kAccessSize * (SLICE_K / GROUP_SIZE - kIterK) * CTA_M;
}
else { // SLICE_K < GROUP_SIZE, recompute `src_offset_`
src_offset_k_ += CTA_K;
src_offset_ = (src_offset_k_ / GROUP_SIZE) * m_ + src_offset_m_;
dst_offset_ += dst_step_k_;
}
if (dst_offset_ >= kSmemByteSize) {
dst_offset_ -= kSmemByteSize;
}
}
__device__ void prefetch(bool mask)
{
cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
}
};
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
struct IteratorB {
static constexpr int SLICE_K = CTA_K / SLICES;
static constexpr int kElementSize = sizeof(half);
using AccessType = uint4;
static constexpr int kAccessSize = sizeof(AccessType);
static constexpr int kShapeK = SLICE_K;
static constexpr int kShapeN = CTA_N;
static constexpr int kAccessK = kAccessSize / sizeof(half);
static_assert(kShapeK % kAccessSize == 0);
// warp thread arrangement
static constexpr int kWarpThreadC = std::max(kShapeK / kAccessK, 1);
static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
// warp shape per access
static constexpr int kWarpAccessK = kWarpThreadC * kAccessK;
static constexpr int kWarpAccessN = kWarpThreadS;
// warp access iterations
static constexpr int kWarpIterK = kShapeK / kWarpAccessK;
static constexpr int kWarpIterN = kShapeN / kWarpAccessN;
// warp arrangement
static constexpr int kWarpK = kWarpIterK >= WARPS ? WARPS : kWarpIterK;
static constexpr int kWarpN = WARPS > kWarpIterK ? WARPS / kWarpK : 1;
// iterations
static constexpr int kIterK = kWarpIterK / kWarpK;
static constexpr int kIterN = kWarpIterN >= kWarpN ? kWarpIterN / kWarpN : 1;
static constexpr int kIterCount = kIterK * kIterN;
static_assert(kIterCount > 0);
// warp footprint
static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
static constexpr int kWarpFootprintN = kWarpAccessN * kIterN;
// Eliminate bank-conflicts for 8x4 half2 tiles, watch out for misalignment
static constexpr int kSmemPadCtaK = SLICE_K + 8;
static constexpr int kSizePerTile = CTA_N * kSmemPadCtaK;
static constexpr int kSmemByteSize = kElementSize * STAGES * kSizePerTile;
const half* src_;
AccessType* const smem_; // [CTA_N, SLICE_K + 8]
const uint32_t smem_int_ptr_;
const int k_;
const int n_;
const int cta_n_;
const int warp_id_;
const int lane_id_;
const int c_;
const int s_;
int src_offset_n_;
int src_offset_;
int dst_offset_;
int src_step_k_;
int src_step_n_;
int dst_step_k_;
int dst_step_n_;
bool is_valid_n_;
int tmp_src_offset_;
int tmp_dst_offset_;
int tmp_src_offset_n_;
int iter_k_{0};
int iter_n_{0};
IteratorB() = default;
__device__ IteratorB(const half* src, void* smem, int k, int n, int cta_n, int cta_k, int warp_id, int lane_id):
src_(src),
smem_((AccessType*)smem),
smem_int_ptr_(cast_smem_ptr_to_uint(smem)),
k_(k),
n_(n),
cta_n_(cta_n),
warp_id_(warp_id),
lane_id_(lane_id),
c_(lane_id_ % kWarpThreadC),
s_(lane_id_ / kWarpThreadC)
{
const int warp_offset_k = warp_id_ % kWarpK;
const int warp_offset_n = warp_id_ / kWarpK;
const int warp_thread_offset_k = lane_id_ % kWarpThreadC;
const int warp_thread_offset_n = lane_id_ / kWarpThreadC;
const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
const int cta_thread_offset_n = kWarpFootprintN * warp_offset_n + warp_thread_offset_n;
const int src_offset_k = cta_thread_offset_k + cta_k;
src_offset_n_ = cta_thread_offset_n + cta_n_;
src_offset_ = src_offset_n_ * k_ + src_offset_k;
const int dst_offset_k = cta_thread_offset_k;
const int dst_offset_n = cta_thread_offset_n;
dst_offset_ = dst_offset_n * kSmemPadCtaK + dst_offset_k;
src_step_k_ = kWarpAccessK;
src_step_n_ = kWarpAccessN * k_ - kIterK * kWarpAccessK;
dst_step_k_ = kWarpAccessK;
dst_step_n_ = kWarpAccessN * kSmemPadCtaK - kIterK * kWarpAccessK;
dst_offset_ *= kElementSize;
dst_step_k_ *= kElementSize;
dst_step_n_ *= kElementSize;
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
tmp_src_offset_n_ = src_offset_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
}
__device__ void prefetch_stage(bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
prefetch(mask);
++(*this);
}
next_stage();
}
__device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
prefetch(mask);
++(*this);
}
}
}
__device__ IteratorB& operator++()
{
if (!is_valid_n_) {
return *this;
}
// move to next k
tmp_src_offset_ += src_step_k_;
tmp_dst_offset_ += dst_step_k_;
++iter_k_;
if (iter_k_ < kIterK) {
return *this;
}
// move to next n
iter_k_ = 0;
tmp_src_offset_n_ += kWarpAccessN;
tmp_src_offset_ += src_step_n_;
tmp_dst_offset_ += dst_step_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
++iter_n_;
return *this;
}
__device__ void next_stage()
{
iter_n_ = 0;
src_offset_ += CTA_K;
dst_offset_ += kElementSize * kSizePerTile;
if (dst_offset_ >= kSmemByteSize) {
dst_offset_ -= kSmemByteSize;
}
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
tmp_src_offset_n_ = src_offset_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
}
__device__ void prefetch(bool mask)
{
cp_async_cg_B(
smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
}
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "common.h"
#include <iostream>
namespace turbomind {
__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
{
uint32_t old = *address;
uint32_t assumed;
do {
assumed = old;
uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));
old = atomicCAS(address, assumed, tmp);
} while (assumed != old);
}
__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
{
return (*address >> (index * 4u)) & 0xfu;
}
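// e.g. if *address == 0x87654321, read_u4(address, 0) == 0x1 and read_u4(address, 7) == 0x8;
// atomic_assign_u4(address, 0, 0xf) would CAS the word to 0x8765432f, leaving the other nibbles intact.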
template<int... Ds>
__global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
{
constexpr int N = sizeof...(Ds);
size_t count = 1;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
count *= dims[i];
}
constexpr int order[] = {Ds...};
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
int indices[N]{};
PRAGMA_UNROLL
for (int j = N - 1, ii = i; j >= 0; --j) {
indices[j] = ii % dims[j];
ii /= dims[j];
}
auto data = read_u4(src + i / 8, i % 8);
int index = 0;
PRAGMA_UNROLL
for (int j = N - 1, stride = 1; j >= 0; --j) {
index += indices[order[j]] * stride;
stride *= dims[order[j]];
}
atomic_assign_u4(dst + index / 8, index % 8, data);
}
}
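// Illustrative sketch (the wrapper name and shape are assumptions, not part of this file):
// permute_u4<D0, ..., Dn-1> writes an output whose dimensions are (dims[D0], ..., dims[Dn-1]),
// with the last listed dimension fastest-varying. A plain transpose of a row-major [R, C]
// tensor of packed 4-bit values (R * C divisible by 8) could therefore be launched as:
inline void transpose_u4_sketch(uint32_t* dst, const uint32_t* src, int R, int C, cudaStream_t st)
{
    Array<int, 2> shape{R, C};
    permute_u4<1, 0><<<512, 512, 0, st>>>(dst, src, shape);  // dst[c][r] = src[r][c]
}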
void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k/8, m] layout
Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
// |warp| lane | 2x2 | a0-7 |
permute_u4<0, 3, 6, 8, 9, 1, 4, 7, 2, 5><<<512, 512, 0, st>>>(dst, src, shape);
}
void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k, m/8] layout
Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
// |warp| lane | 2x2 | a0-7 |
permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
}
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
{
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
}
}
__global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
{
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// dequant via HFMA2 has numerical stability issues
Q[i] = __halves2half2(-zeros[i] * scales[i], scales[i]);
}
else {
Q[i] = __halves2half2(zeros[i], scales[i]);
}
}
}
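// Plain-float sketch of the two dequant forms this pairs with in apply_Q (common.h);
// the function names are illustrative only. The default path stores (zero, scale) and
// computes (x - zero) * scale; the FMA path pre-folds the zero point into -zero * scale
// so a single fma yields the same result, at some cost in numerical robustness.
inline float dequant_sub_mul_sketch(float x, float zero, float scale)
{
    return (x - zero) * scale;
}
inline float dequant_fma_sketch(float x, float neg_zero_scale, float scale)
{
    return x * scale + neg_zero_scale;  // == (x - zero) * scale when neg_zero_scale == -zero * scale
}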
void convert_s4_k_m8(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st)
{
dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
reformat_s4_k_m8(A_dst, A_src, m, k, st);
}
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{
Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
// dequant transpose quant
// 0123456 -> 0123564 -> 0135642 -> 0135264
permute_u4<0, 1, 3, 5, 2, 6, 4><<<512, 512, 0, st>>>(dst, src, shape);
}
// [2, k, m/8] -> [k, m/8, 2]
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
// dequant transpose quant
// 012345 -> 012453 -> 124530 -> 124053
permute_u4<1, 2, 4, 0, 5, 3><<<512, 512, 0, st>>>(dst, src, shape);
}
__global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
{
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
dst[i] = dequantize_s4_to_fp16x2(src[i]);
}
}
void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
{
dequantize_s4_kernel<<<512, 512, 0, st>>>(dst, src, count);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
namespace turbomind {
void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
void convert_s4_k_m8(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st = {});
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st = {});
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "gemm_s4_f16.h"
#include "gemm_s4_f16_kernel.h"
#include "metric.h"
#include <algorithm>
#include <iomanip>
#include <ios>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <vector>
namespace turbomind {
bool g_dump_kernel_info_once = false;
namespace ops {
struct Identity {
static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
{
if (n < N) {
(uint&)C[n * M + m] = (uint&)data;
}
}
};
struct SiluActivation {
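// For the fused W1/W3 GEMM, consecutive output channels along M hold interleaved
// (gate, up) pairs, so each half2 `data` carries one pair and the activated result
// goes to a buffer of half the M extent (hence the M / 2 and m / 2 indexing below).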
static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
{
auto u = __half22float2((half2&)data);
float silu = u.x / (1.f + __expf(-u.x));
half val = __float2half_rn(silu * u.y);
if (n < N) {
C[n * (M / 2) + m / 2] = val;
}
}
};
} // namespace ops
template<typename... Ts>
struct OutputOps {
template<int index>
static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
{
std::tuple_element_t<index, std::tuple<Ts...>>::apply(data, m, n, C, M, N);
}
};
struct GemmS4F16::Impl {
using Kernels = std::vector<std::unique_ptr<IGemmKernel>>;
template<int GS, typename Op>
void Generate(std::vector<Kernels>& kernels)
{
// smem size (KB):
// sm75: 64
// sm80: 163
// sm86: 99
// sm89: 99
// sm90: 227
Kernels k;
// 256
k.emplace_back(new GemmKernel<Shape<256, 128, 32>, Shape<32, 128, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 64, 64>, Shape<64, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 64, 32>, Shape<64, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 32, 64>, Shape<64, 32, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 16, 256>, Shape<32, 16, 128>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 8, 256>, Shape<32, 8, 128>, 3, GS, Op>{});
// 128
k.emplace_back(new GemmKernel<Shape<128, 128, 64>, Shape<32, 128, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 128, 32>, Shape<32, 128, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 96, 64>, Shape<32, 96, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 64, 64>, Shape<32, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 64, 32>, Shape<32, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 32, 128>, Shape<32, 32, 64>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 16, 256>, Shape<32, 16, 64>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 8, 512>, Shape<32, 8, 128>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 8, 512>, Shape<32, 8, 128>, 2, GS, Op>{}); // for 86/89
// 64
k.emplace_back(new GemmKernel<Shape<64, 16, 256>, Shape<32, 16, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<64, 8, 256>, Shape<32, 8, 32>, 3, GS, Op>{});
kernels.push_back(std::move(k));
}
void Measure(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
std::vector<Metric>& metrics,
cudaStream_t st,
std::vector<Kernels>& _kernels)
{
int gid = -1;
for (size_t i = 0; i < group_sizes_.size(); ++i) {
if (group_sizes_[i] == group_size) {
gid = i;
break;
}
}
if (gid < 0) {
throw std::runtime_error("unsupported group size");
}
const auto& kernels = _kernels[gid];
metrics = std::vector<Metric>(kernels.size());
int best = 0;
for (size_t i = 0; i < kernels.size(); ++i) {
metrics[i].id = i;
kernels[i]->GetMetric(metrics[i], m, n, k);
if (!metrics[i].feasible) {
metrics[i].time = std::numeric_limits<float>::infinity();
metrics[i].count = 1;
continue;
}
if (Compare(metrics[i], metrics[best])) {
best = i;
}
for (size_t j = 0; j < kWarmup + kMeasure; ++j) {
if (j == kWarmup) {
cudaEventRecord(ev_start_, st);
}
kernels[i]->Launch(C, A, B, Q, m, n, k, type, st);
}
cudaEventRecord(ev_end_, st);
cudaEventSynchronize(ev_end_);
float ms{};
cudaEventElapsedTime(&ms, ev_start_, ev_end_);
metrics[i].time = ms;
metrics[i].count = kMeasure;
}
metrics[best].best = 1;
// sort metrics
std::vector<int> indices(kernels.size());
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(
indices.begin(), indices.end(), [&](int i, int j) { return metrics[i].time < metrics[j].time; });
if (g_dump_kernel_info_once) {
DumpMetrics(std::cerr, metrics, indices);
g_dump_kernel_info_once = 0;
}
std::vector<Metric> tmp;
for (size_t i = 0; i < indices.size(); ++i) {
tmp.push_back(metrics[indices[i]]);
}
metrics.swap(tmp);
}
static bool Compare(const Metric& a, const Metric& b)
{
if (a.feasible != b.feasible) {
return a.feasible > b.feasible;
}
if (a.prefer != b.prefer) {
return a.prefer > b.prefer;
}
return a.grid_norm < b.grid_norm;
}
int Estimate(int m, int n, int k, Kernels& kernels)
{
int best = 0;
std::vector<Metric> metrics(kernels.size());
for (size_t i = 0; i < kernels.size(); ++i) {
metrics[i].id = i;
kernels[i]->GetMetric(metrics[i], m, n, k);
if (Compare(metrics[i], metrics[best])) {
best = i;
}
}
if (g_dump_kernel_info_once) {
std::vector<int> indices(kernels.size());
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(
indices.begin(), indices.end(), [&](int i, int j) { return Compare(metrics[i], metrics[j]); });
DumpMetrics(std::cerr, metrics, indices);
g_dump_kernel_info_once = 0;
}
return best;
}
void Run(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
int algo_id,
cudaStream_t st,
std::vector<Kernels>& kernels)
{
for (size_t i = 0; i < group_sizes_.size(); ++i) {
if (group_sizes_[i] == group_size) {
if (algo_id < 0) {
algo_id = Estimate(m, n, k, kernels[i]);
}
if (algo_id < 0) {
throw std::runtime_error("no feasible kernel found");
}
kernels[i].at(algo_id)->Launch(C, A, B, Q, m, n, k, type, st);
return;
}
}
throw std::runtime_error("unsupported group size");
}
Impl()
{
cudaEventCreate(&ev_start_);
cudaEventCreate(&ev_end_);
using Ops = OutputOps<ops::Identity, ops::SiluActivation>;
/// TODO: add more group sizes
Generate<128, Ops>(kernels_);
group_sizes_.push_back(128);
}
~Impl()
{
cudaEventDestroy(ev_end_);
cudaEventDestroy(ev_start_);
}
std::vector<Kernels> kernels_;
std::vector<int> group_sizes_;
static constexpr int kWarmup = 10;
static constexpr int kMeasure = 100;
cudaEvent_t ev_start_{};
cudaEvent_t ev_end_{};
};
GemmS4F16::GemmS4F16(): impl_(std::make_unique<Impl>()) {}
GemmS4F16::~GemmS4F16() = default;
void GemmS4F16::Measure(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
std::vector<Metric>& metrics,
cudaStream_t st)
{
impl_->Measure(C, A, B, Q, m, n, k, group_size, type, metrics, st, impl_->kernels_);
}
void GemmS4F16::Run(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
int algo_id,
cudaStream_t st)
{
impl_->Run(C, A, B, Q, m, n, k, group_size, type, algo_id, st, impl_->kernels_);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "metric.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <memory>
#include <vector>
namespace turbomind {
extern bool g_dump_kernel_info_once;
class GemmS4F16 {
public:
GemmS4F16();
~GemmS4F16();
enum Type
{
kGemm,
kFusedSiluFfn
};
void Measure(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
std::vector<Metric>& metrics,
cudaStream_t st);
void Run(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
int algo_id,
cudaStream_t st);
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace turbomind
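// Illustrative calling sketch (the wrapper name and buffer preparation are assumptions,
// not part of this header): A holds 4-bit weights pre-formatted by convert_s4_k_m8,
// Q holds the per-group quantization params produced by merge_Q, group_size = 128 is the
// only group size registered by Impl, and algo_id = -1 lets the cost model pick a kernel.
inline void run_w4a16_gemm_sketch(half* C, const unsigned int* A, const half* B, const half2* Q,
                                  int m, int n, int k, cudaStream_t stream)
{
    static turbomind::GemmS4F16 gemm;  // builds the kernel list and timing events once
    gemm.Run(C, A, B, Q, m, n, k, /*group_size=*/128,
             turbomind::GemmS4F16::kGemm, /*algo_id=*/-1, stream);
}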
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "gemm_template.h"
#include "metric.h"
#include <iostream>
#include <memory>
#include <sstream>
namespace turbomind {
struct IGemmKernel {
virtual ~IGemmKernel() = default;
virtual void GetMetric(Metric& metric, int m, int n, int k) = 0;
virtual void Launch(half* C,
const uint* A,
const half* B,
const half2* Q,
int M,
int N,
int K,
int output_op_idx,
cudaStream_t) = 0;
virtual void Dump(std::ostream& os) = 0;
};
template<typename CtaShape, typename WarpShape, int Stages, int GroupSize, typename OutputOps>
struct GemmKernel: public IGemmKernel {
static constexpr CtaShape cta_shape{};
static constexpr WarpShape warp_shape{};
using GemmType = Gemm<cta_shape.m(),
cta_shape.n(),
cta_shape.k(),
warp_shape.m(),
warp_shape.n(),
warp_shape.k(),
Stages,
GroupSize,
OutputOps>;
decltype(&gemm_s4_f16_nn<GemmType>) kernel_func_;
std::shared_ptr<cudaDeviceProp> props_;
int max_active_ctas_{};
static constexpr int kSlices = GemmType::SLICES;
static constexpr int kSmemSizeA = GemmType::IteratorA::kSmemByteSize * kSlices;
static constexpr int kSmemSizeB = GemmType::IteratorB::kSmemByteSize * kSlices;
static constexpr int kSmemSizeC = sizeof(float) * cta_shape.m() * cta_shape.n();
static constexpr int kSmemByteSize = std::max(kSmemSizeA + kSmemSizeB, kSmemSizeC);
// static shared memory size of Q
static constexpr int kSmemSizeQ = sizeof(typename GemmType::IteratorQ::Storage);
explicit GemmKernel(std::shared_ptr<cudaDeviceProp> props = {}): props_(std::move(props))
{
if (!props_) {
props_ = std::make_shared<cudaDeviceProp>();
int device_id = -1;
cudaGetDevice(&device_id);
cudaGetDeviceProperties(props_.get(), device_id);
}
kernel_func_ = gemm_s4_f16_nn<GemmType>;
cudaFuncSetAttribute(kernel_func_, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_ctas_, kernel_func_, GemmType::kWarpCount * WARP_SIZE, kSmemByteSize);
};
bool is_feasible(int m, int n, int k)
{
return m % cta_shape.m() == 0 && k % cta_shape.k() == 0;
}
void GetMetric(Metric& metric, int m, int n, int k) override
{
metric.cta_shape = {cta_shape.m(), cta_shape.n(), cta_shape.k()};
metric.warp_shape = {warp_shape.m(), warp_shape.n(), warp_shape.k()};
metric.warps = GemmType::kWarpCount;
metric.stages = Stages;
metric.smem = (kSmemByteSize + kSmemSizeQ) / 1024.f;
metric.feasible = is_feasible(m, n, k) && max_active_ctas_ > 0;
metric.prefer = cta_shape.m() != 64 || m <= k;
if (!metric.feasible) {
return;
}
int grid_size = ((m + cta_shape.m() - 1) / cta_shape.m()) * ((n + cta_shape.n() - 1) / cta_shape.n());
metric.grid_size = grid_size;
metric.max_active_ctas = max_active_ctas_;
metric.active_ctas =
std::min(max_active_ctas_, (grid_size + props_->multiProcessorCount - 1) / props_->multiProcessorCount);
metric.waves = (float)grid_size / (props_->multiProcessorCount * metric.active_ctas);
metric.occupancy = (metric.active_ctas * GemmType::kWarpCount)
/ (float)(props_->maxThreadsPerMultiProcessor / props_->warpSize);
metric.cta_cnt_m = (m + cta_shape.m() - 1) / cta_shape.m();
metric.cta_cnt_n = (n + cta_shape.n() - 1) / cta_shape.n();
metric.cta_iter_k = (k + cta_shape.k() - 1) / cta_shape.k();
metric.tile_efficiency = (float)n / (metric.cta_cnt_n * cta_shape.n());
metric.wave_efficiency = metric.waves / std::ceil(metric.waves);
const int m_pad = (m + cta_shape.m() - 1) / cta_shape.m() * cta_shape.m();
const int n_pad = (n + cta_shape.n() - 1) / cta_shape.n() * cta_shape.n();
metric.grid_a0 = 0.25f * m * n_pad / cta_shape.n(); // Ta0 * M * [N / ctaN]
metric.grid_b0 = 1.00f * n * m_pad / cta_shape.m(); // Tb0 * N * [M / ctaM]
metric.grid_a1 = 0.65f * m_pad * n_pad / warp_shape.n(); // Ta1 * [M] * [N] / warpN
metric.grid_b1 = 0.25f * m_pad * n_pad / warp_shape.m(); // Tb1 * [M] * [N] / warpM
metric.grid_mm = 1.00f * m_pad * n_pad / 64; // Tm * [M] * [N]
metric.grid_sum = metric.grid_a0 + metric.grid_b0 + metric.grid_a1 + metric.grid_b1 + metric.grid_mm;
metric.cta_sum = metric.grid_sum / grid_size;
metric.waves1 = (float)grid_size / (props_->multiProcessorCount * metric.active_ctas);
metric.cta_wave = std::ceil(metric.waves1) * metric.active_ctas;
metric.grid_norm = metric.cta_wave * metric.cta_sum;
}
void Launch(
half* C, const uint* A, const half* B, const half2* Q, int M, int N, int K, int output_op_idx, cudaStream_t st)
override
{
constexpr int block_size = GemmType::kWarpCount * WARP_SIZE;
dim3 grid_size((M + cta_shape.m() - 1) / cta_shape.m(), (N + cta_shape.n() - 1) / cta_shape.n());
kernel_func_<<<grid_size, block_size, kSmemByteSize, st>>>(C, A, B, Q, M, N, K, output_op_idx);
}
void Dump(std::ostream& os) override
{
{
os << "[Gemm] CTA shape: " << cta_shape.m() << "x" << cta_shape.n() << "x" << cta_shape.k() << std::endl;
os << "[Gemm] warp shape: " << warp_shape.m() << "x" << warp_shape.n() << "x" << warp_shape.k()
<< std::endl;
os << "[Gemm] warp count: " << GemmType::kWarpCountM << "x" << GemmType::kWarpCountN << "x"
<< GemmType::kWarpCountK << " (" << GemmType::kWarpCount << ")" << std::endl;
os << std::endl;
}
{
using Iter = typename GemmType::IteratorA;
os << "[A] shape: " << Iter::kShapeM << " " << Iter::kShapeK << std::endl;
os << "[A] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
os << "[A] warp shape per access: " << Iter::kWarpAccessM << " " << Iter::kWarpAccessK << std::endl;
os << "[A] warp access iters: " << Iter::kWarpIterM << " " << Iter::kWarpIterK << std::endl;
os << "[A] warp arrangement: " << Iter::kWarpM << " " << Iter::kWarpK << std::endl;
os << "[A] iterations: " << Iter::kIterM << " " << Iter::kIterK << std::endl;
os << "[A] iters per tile: " << Iter::kIterCount << std::endl;
os << "[A] warp footprint: " << Iter::kWarpFootprintM << " " << Iter::kWarpFootprintK << std::endl;
os << "[A] shared memory: " << Iter::kSmemByteSize << std::endl;
os << std::endl;
}
{
using Iter = typename GemmType::IteratorB;
os << "[B] shape: " << Iter::kShapeK << " " << Iter::kShapeN << std::endl;
os << "[B] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
os << "[B] warp shape per access: " << Iter::kWarpAccessK << " " << Iter::kWarpAccessN << std::endl;
os << "[B] warp access iters: " << Iter::kWarpIterK << " " << Iter::kWarpIterN << std::endl;
os << "[B] warp arrangement: " << Iter::kWarpK << " " << Iter::kWarpN << std::endl;
os << "[B] iterations: " << Iter::kIterK << " " << Iter::kIterN << std::endl;
os << "[B] iters per tile: " << Iter::kIterCount << std::endl;
os << "[B] warp footprint: " << Iter::kWarpFootprintK << " " << Iter::kWarpFootprintN << std::endl;
os << "[B] shared memory: " << Iter::kSmemByteSize << std::endl;
os << std::endl;
}
{
using Iter = typename GemmType::IteratorQ;
// os << "[Q] shape: " << CTA_M << " " << Iter::SLICE_K << std::endl;
os << "[Q] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
os << "[Q] warp shape per access: " << Iter::kWarpAccessM << " " << Iter::kWarpAccessK << std::endl;
os << "[Q] warp access iters: " << Iter::kWarpIterM << " " << Iter::kWarpIterK << std::endl;
os << "[Q] warp arrangement: " << Iter::kWarpM << " " << Iter::kWarpK << std::endl;
os << "[Q] iterations: " << Iter::kIterM << " " << Iter::kIterK << std::endl;
os << "[Q] iters per tile: " << Iter::kIterCount << std::endl;
os << "[Q] warp footprint: " << Iter::kWarpFootprintM << " " << Iter::kWarpFootprintK << std::endl;
os << "[Q] size per stage: " << Iter::kSizePerStage << std::endl;
os << "[Q] shared memory: " << Iter::kSmemByteSize << std::endl;
os << std::endl;
}
os << "Dynamic shared memory size: " << kSmemByteSize << std::endl;
}
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "common.h"
#include "cta_iterator.h"
#include "warp_iterator.h"
#include <cuda_pipeline_primitives.h>
namespace turbomind {
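// f16 x f16 -> f32 m16n8k16 tensor-core MMA issued via inline PTX; guarded for SM80.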
__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
{
#if TURBOMIND_ARCH_SM80
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
float const* C = reinterpret_cast<float const*>(&c);
float* D = reinterpret_cast<float*>(&d);
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
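// Transpose an 8x8 b16 tile held across the registers of a warp using `movmatrix` (SM75+).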
__inline__ __device__ uint transpose_m8n8_b16(uint a)
{
#if TURBOMIND_ARCH_SM75
uint d;
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
return d;
#else
assert(TURBOMIND_ARCH_SM75);
return 0;
#endif
}
namespace ops {
__inline__ __device__ float4 operator+(const float4& a, const float4& b)
{
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}
__inline__ __device__ float2 operator+(const float2& a, const float2& b)
{
return {a.x + b.x, a.y + b.y};
}
} // namespace ops
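// Tiled W4A16 GEMM: A holds packed 4-bit values that are dequantized on the fly with the
// per-group scale/zero pairs in Q, while B and the output C are f16; CTA/warp tile sizes,
// pipeline depth (STAGES) and quantization group size are template parameters.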
template<int CTA_M,
int CTA_N,
int CTA_K,
int WARP_M,
int WARP_N,
int WARP_K,
int STAGES,
int GROUP_SIZE,
typename OutputOps>
struct Gemm {
static constexpr int kWarpCountM = CTA_M / WARP_M;
static constexpr int kWarpCountN = CTA_N / WARP_N;
static constexpr int kWarpCountK = CTA_K / WARP_K;
static constexpr int kWarpCountMN = kWarpCountM * kWarpCountN;
static constexpr int kWarpCount = kWarpCountMN * kWarpCountK;
static constexpr int SLICES = kWarpCountK;
static constexpr int SLICE_K = CTA_K / SLICES;
static_assert(SLICE_K % WARP_K == 0, "infeasible sliced-k setting");
using IteratorA = turbomind::IteratorA<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES>;
using IteratorQ = turbomind::IteratorQ<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES, GROUP_SIZE>;
using IteratorB = turbomind::IteratorB<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES>;
static constexpr int OP_M = 16;
static constexpr int OP_N = 8;
static constexpr int OP_K = 16;
using WarpIterA = turbomind::WarpIteratorA<CTA_M,
CTA_K,
WARP_M,
WARP_K,
OP_M,
OP_K,
GROUP_SIZE,
STAGES,
IteratorA::kSizePerStage,
IteratorQ::kSizePerStage>;
using WarpIterB =
turbomind::WarpIteratorB<CTA_N, CTA_K, WARP_N, WARP_K, OP_N, OP_K, IteratorB::kSmemPadCtaK, STAGES>;
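    // One CTA_K step of the pipelined main loop: ITER_K rounds of tensor-core MMAs on
    // double-buffered warp fragments, overlapped with asynchronous prefetch of the next
    // stage's A/Q/B tiles into shared memory.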
__device__ void warp_mma(IteratorA& iter_A,
IteratorQ& iter_Q,
IteratorB& iter_B,
WarpIterA& warp_iter_A,
WarpIterB& warp_iter_B,
float* accum,
int slice_id,
int& gemm_iter)
{
constexpr int ITER_M = WARP_M / OP_M;
constexpr int ITER_N = WARP_N / OP_N;
constexpr int ITER_K = WARP_K / OP_K;
constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;
auto frag_C_ptr = (Array<float, 4>*)accum; // [ITER_N, ITER_M]
PRAGMA_UNROLL
for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {
warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
auto warp_frag_A = warp_frag_A_[iter_k % 2];
auto warp_frag_B = warp_frag_B_[iter_k % 2];
PRAGMA_UNROLL
for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
PRAGMA_UNROLL
for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
auto& frag_A = warp_frag_A[iter_m];
auto& frag_B = warp_frag_B[iter_n];
auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
}
}
if (iter_k < ITER_K - 1) {
iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
}
if (iter_k == ITER_K - 2) {
iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id);
iter_A.next_stage();
iter_Q.next_stage();
iter_B.next_stage();
warp_iter_A.next_stage();
warp_iter_B.next_stage();
--gemm_iter;
}
}
}
template<typename T, int N>
__device__ static void copy(T (&dst)[N], const T (&src)[N])
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
template<typename T, int N>
__device__ static void clear(T (&dst)[N])
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = T{};
}
}
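    // With a single k-slice a full __syncthreads() is enough; otherwise only the warps
    // belonging to the same slice synchronize, using a named barrier.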
__device__ void sync_slice(int slice_id)
{
if constexpr (SLICES == 1) {
__syncthreads();
}
else {
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
}
__device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
{
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) {
tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
}
}
}
__device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
{
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) {
partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
}
}
}
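    // Convert the f32 accumulators to f16, transpose each 8x8 tile so that rows become
    // contiguous in registers, then let the selected output op write to global memory.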
template<int Index>
__device__ void store_accum(float* tb_frag_C,
float* tb_smem_C,
half* C,
int m,
int n,
int cta_m,
int cta_n,
int warp_id_m,
int warp_id_n,
int lane_id,
int slice_id)
{
if (slice_id != 0) {
return;
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 2; ++x) {
const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
// convert to half
half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// transpose 8x8 accum tile
uint trans_C = transpose_m8n8_b16((uint&)half_C);
// store to global memory
OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
}
}
}
}
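    // Sliced-k reduction: each slice in turn accumulates its partial results into shared
    // memory, then slice 0 reads back the final sums.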
__device__ void
sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
{
int offset_m = warp_id_m * WARP_M / OP_M;
int offset_n = warp_id_n * WARP_N / OP_N;
PRAGMA_UNROLL
for (int z = 0; z < SLICES; ++z) {
if (slice_id == z) {
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) {
int src = (i * WARP_M / OP_M + j) * 4 + x;
int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
if (z > 0) {
using namespace ops;
tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
}
tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
}
}
}
}
__syncthreads();
}
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) {
int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
int dst = (i * WARP_M / OP_M + j) * 4 + x;
tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
}
}
}
}
}
Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
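    // Kernel body: derive per-warp coordinates, partition shared memory, prefetch the
    // first STAGES-1 stages, run the pipelined main loop, reduce across k-slices if
    // needed and store the output tile.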
__device__ void run_v2(half* __restrict__ C,
const uint* __restrict__ A,
const half* __restrict__ B,
const half2* __restrict__ Q,
int M,
int N,
int K,
int output_op_idx)
{
static_assert(WARP_M % OP_N == 0);
float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];
extern __shared__ uint8_t smem[];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id_m = warp_id % kWarpCountM;
const int warp_id_nk = warp_id / kWarpCountM;
const int warp_id_n = warp_id_nk % kWarpCountN;
const int warp_id_k = warp_id_nk / kWarpCountN;
const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;
const int slice_id = warp_id_k;
const int cta_k = slice_id * SLICE_K; // sliced-k offset
const int cta_m = blockIdx.x * CTA_M;
const int cta_n = blockIdx.y * CTA_N;
// each slice has its own partition of smem
uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);
// [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
float* const tb_smem_C = (float*)smem;
__shared__ typename IteratorQ::Storage tb_smem_Q_storage;
auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];
IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};
const int offset_m = warp_id_m * WARP_M + lane_id;
WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);
int gemm_iter = (K + CTA_K - 1) / CTA_K;
PRAGMA_UNROLL
for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
iter_A.prefetch_stage(gemm_iter > 0);
iter_Q.prefetch_stage(gemm_iter > 0);
iter_B.prefetch_stage(gemm_iter > 0);
__pipeline_commit();
}
clear(tb_frag_C);
__pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id);
warp_iter_A.load(warp_frag_A_[0], 0);
warp_iter_B.load(warp_frag_B_[0], 0);
PRAGMA_NO_UNROLL
for (; gemm_iter > -STAGES + 1;) {
warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1) {
sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
}
switch (output_op_idx) {
case 0:
store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break;
case 1:
store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break;
default:
return;
}
}
};
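// Thin __global__ entry point; all the work happens in Gemm::run_v2.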
template<typename Gemm>
__global__ void gemm_s4_f16_nn(half* __restrict__ C,
const uint* __restrict__ A,
const half* __restrict__ B,
const half2* __restrict__ Q,
int M,
int N,
int K,
int output_op_idx)
{
Gemm{}.run_v2(C, A, B, Q, M, N, K, output_op_idx);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <array>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
namespace turbomind {
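// Per-configuration statistics gathered by the tile-selection cost model; DumpMetrics
// renders them as a fixed-width table for inspection.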
struct Metric {
int id;
bool feasible;
bool prefer;
std::array<int, 3> cta_shape;
std::array<int, 3> warp_shape;
int warps;
int stages;
int max_active_ctas;
float smem;
float cta_cnt_m;
float cta_cnt_n;
float cta_iter_k;
float grid_size;
int active_ctas;
float waves;
float waves1;
float occupancy;
float tile_efficiency;
float wave_efficiency;
float grid_a0;
float grid_b0;
float grid_a1;
float grid_b1;
float grid_mm;
float grid_sum;
float grid_norm;
float cta_sum;
float cta_wave;
int best;
float time;
int count;
};
inline void DumpMetrics(std::ostream& os, const std::vector<Metric>& metrics, const std::vector<int>& indices = {})
{
auto dump_shape = [](const std::array<int, 3>& shape) {
std::stringstream ss;
ss << std::setw(4) << shape[0] << std::setw(4) << shape[1] << std::setw(4) << shape[2];
return ss.str();
};
std::vector<std::tuple<std::string, int>> infos{
{"id", 4}, {"valid", 6}, {"cta_mnk", 14}, {"warp_mnk", 14}, {"warps", 6}, {"stages", 8},
{"smem", 8}, {"cta_cnt_m", 10}, {"cta_cnt_n", 10}, {"cta_iter_k", 11}, {"max_ctas", 9}, {"act_ctas", 10},
{"waves", 12}, {"waves1", 12}, {"occupancy", 12}, {"%tile", 10}, {"%wave", 10}, {"grid_a0", 12},
{"grid_b0", 12}, {"grid_a1", 12}, {"grid_b1", 12}, {"grid_mm", 12}, {"grid_sum", 12}, {"cta_cnt", 8},
{"cta_sum", 8}, {"cta_wave", 9}, {"grid_norm", 12}, {"time", 12}, {"best", 7}};
for (const auto& [name, width] : infos) {
os << std::setw(width) << name;
}
os << "\n";
for (size_t i = 0; i < metrics.size(); ++i) {
auto& metric = indices.empty() ? metrics[i] : metrics[indices[i]];
int c = 0;
os << std::setw(std::get<1>(infos[c++])) << metric.id;
os << std::setw(std::get<1>(infos[c++])) << metric.feasible;
os << std::setw(std::get<1>(infos[c++])) << dump_shape(metric.cta_shape);
os << std::setw(std::get<1>(infos[c++])) << dump_shape(metric.warp_shape);
os << std::setw(std::get<1>(infos[c++])) << metric.warps;
os << std::setw(std::get<1>(infos[c++])) << metric.stages;
os << std::setw(std::get<1>(infos[c++])) << metric.smem;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_cnt_m;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_cnt_n;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_iter_k;
os << std::setw(std::get<1>(infos[c++])) << metric.max_active_ctas;
os << std::setw(std::get<1>(infos[c++])) << metric.active_ctas;
os << std::setw(std::get<1>(infos[c++])) << metric.waves;
os << std::setw(std::get<1>(infos[c++])) << metric.waves1;
os << std::setw(std::get<1>(infos[c++])) << metric.occupancy;
os << std::setw(std::get<1>(infos[c++])) << metric.tile_efficiency;
os << std::setw(std::get<1>(infos[c++])) << metric.wave_efficiency;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_a0;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_b0;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_a1;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_b1;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_mm;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_sum;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_size;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_sum;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_wave;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_norm;
os << std::setw(std::get<1>(infos[c++])) << metric.time * 1000 / metric.count;
os << std::setw(std::get<1>(infos[c++])) << (metric.best ? "*" : "");
os << "\n";
}
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "common.h"
namespace turbomind {
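// Loads packed 4-bit A fragments and the matching per-group scale/zero pairs (Q) from
// shared memory, dequantizes them and produces the f16 fragments consumed by the MMAs.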
template<int CTA_M,
int CTA_K,
int WARP_M,
int WARP_K,
int OP_M,
int OP_K,
int GROUP_SIZE,
int STAGES,
int kSizePerStageA,
int kSizePerStageQ>
struct WarpIteratorA {
static_assert(WARP_K % GROUP_SIZE == 0 || GROUP_SIZE % WARP_K == 0);
static constexpr int ITER_M = 32 / OP_M;
static constexpr int ITER_X = WARP_M / 32;
    uint4 frag_A4_[ITER_X];   // 8 values per uint
    half2 frag_Q_[ITER_X][4]; // 4 m8k8 tiles along M, as WARP_M == 32
const uint4* smem_A_;
const half2* smem_Q_;
const int offset_m_;
const int offset_m_Q_;
int stage_{0};
int offset_A_{0};
int offset_Q_{0};
__device__ WarpIteratorA(uint4* smem_A, half2* smem_Q, int warp_id, int lane_id, int offset_m, int offset_k):
smem_A_(smem_A), smem_Q_(smem_Q), offset_m_(offset_m), offset_m_Q_(offset_m / 32 * 32 + lane_id / 4)
{
}
    // iter_k must be a compile-time constant
__device__ void load(Array<half, 8>* data, int iter_k)
{
// load A
// smem_A uint4 [SLICE_K/32, CTA_M/32, WARP_SIZE], load as uint4 to avoid bank-conflicts
if (iter_k % 2 == 0) {
PRAGMA_UNROLL
for (int x = 0; x < ITER_X; ++x) {
frag_A4_[x] = smem_A_[offset_A_ + (iter_k / 2) * CTA_M + x * 32 + offset_m_];
}
}
// load Q
if (iter_k * OP_K % GROUP_SIZE == 0) {
const int g = iter_k * OP_K / GROUP_SIZE;
PRAGMA_UNROLL
for (int x = 0; x < ITER_X; ++x) {
PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
const int mm = offset_m_Q_ + x * 32 + i * 8; // stride of m8k8 tile
((uint&)frag_Q_[x][i]) = ((uint&)smem_Q_[offset_Q_ + g * CTA_M + mm]);
}
}
}
PRAGMA_UNROLL
for (int x = 0; x < ITER_X; ++x) {
const uint* frag_A = (uint*)&frag_A4_[x];
PRAGMA_UNROLL
for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
uint4 tmp = dequantize_s4_to_fp16x2_v2(frag_A[iter_k % 2 * 2 + iter_m]);
auto& vec = (Array<half2, 4>&)tmp;
vec[0] = apply_Q(vec[0], frag_Q_[x][iter_m * 2]);
vec[1] = apply_Q(vec[1], frag_Q_[x][iter_m * 2 + 1]);
vec[2] = apply_Q(vec[2], frag_Q_[x][iter_m * 2]);
vec[3] = apply_Q(vec[3], frag_Q_[x][iter_m * 2 + 1]);
data[x * ITER_M + iter_m] = (Array<half, 8>&)vec;
}
}
}
__device__ void next_stage()
{
++stage_;
if (stage_ >= STAGES) {
stage_ = 0;
}
offset_A_ = stage_ * kSizePerStageA;
offset_Q_ = stage_ * kSizePerStageQ;
}
};
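// Loads f16 B fragments from shared memory with ldmatrix (x4, or x2 when WARP_N == 8).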
template<int CTA_N, int CTA_K, int WARP_N, int WARP_K, int OP_N, int OP_K, int SMEM_STRIDE, int STAGES>
struct WarpIteratorB {
static constexpr int kLdsmNum = WARP_N == 8 ? 2 : 4;
static constexpr int ITER_N = WARP_N / OP_N;
static constexpr int ITER_K = WARP_K / OP_K;
static_assert(OP_N == 8 && OP_K == 16);
const int warp_id_n_;
const int lane_id_;
const int ldsm_group_id_;
const int offset_k_;
int offset_n_;
const uint32_t smem_base_ptr_;
uint32_t smem_ptr_;
int stage_{0};
__device__ WarpIteratorB(uint32_t smem_int_ptr, int warp_id_n, int lane_id, int offset_k):
smem_base_ptr_(smem_int_ptr),
smem_ptr_(smem_base_ptr_),
warp_id_n_(warp_id_n),
lane_id_(lane_id),
ldsm_group_id_(lane_id / 8),
offset_k_(ldsm_group_id_ % 2 * 8 + offset_k),
offset_n_(ldsm_group_id_ / 2 * 8 + lane_id % 8)
{
if (kLdsmNum == 2) {
offset_n_ -= ldsm_group_id_ / 2 * 8;
}
offset_n_ += warp_id_n_ * WARP_N;
}
__device__ void load(Array<half, 4>* data, int iter_k)
{
const int kk = iter_k * OP_K + offset_k_;
auto ptr = (uint*)data;
PRAGMA_UNROLL
for (int iter_n = 0; iter_n < ITER_N;) {
const int nn = offset_n_ + iter_n * OP_N;
auto src = smem_ptr_ + sizeof(half) * (nn * SMEM_STRIDE + kk);
if constexpr (kLdsmNum == 4) {
ldmatrix_m8n8_x4_b16(ptr[0], ptr[1], ptr[2], ptr[3], src);
ptr += 4;
iter_n += 2;
}
else {
ldmatrix_m8n8_x2_b16(ptr[0], ptr[1], src);
ptr += 2;
iter_n += 1;
}
}
}
__device__ void next_stage()
{
++stage_;
if (stage_ >= STAGES) {
stage_ = 0;
}
smem_ptr_ = smem_base_ptr_ + stage_ * sizeof(half) * CTA_N * SMEM_STRIDE;
}
};
} // namespace turbomind
......@@ -21,6 +21,7 @@ add_library(Llama STATIC
set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(Llama PUBLIC -lcudart
gemm_s4_f16
cublasMMWrapper
DynamicDecodeLayer
activation_kernels
......
......@@ -19,6 +19,7 @@
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"
#include <filesystem>
......@@ -31,6 +32,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t size_per_head,
size_t inter_size,
WeightType weight_type,
int group_size,
bool attn_bias,
size_t tensor_para_size,
size_t tensor_para_rank):
......@@ -47,22 +49,32 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
self_attn_weights.qkv.input_dims = hidden_units_;
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
self_attn_weights.qkv.type = weight_type;
self_attn_weights.qkv.group_size = group_size;
self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
self_attn_weights.output.output_dims = hidden_units_;
self_attn_weights.output.type = weight_type;
self_attn_weights.output.group_size = group_size;
ffn_weights.gating.input_dims = hidden_units_;
ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.gating.type = weight_type;
ffn_weights.gating.group_size = group_size;
ffn_weights.intermediate.input_dims = hidden_units_;
ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.intermediate.type = weight_type;
ffn_weights.intermediate.group_size = group_size;
ffn_weights.fused_gating_intermediate.input_dims = hidden_units_;
ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
ffn_weights.fused_gating_intermediate.type = weight_type;
ffn_weights.fused_gating_intermediate.group_size = group_size;
ffn_weights.output.input_dims = inter_size_ / tensor_para_size_;
ffn_weights.output.output_dims = hidden_units_;
ffn_weights.output.type = weight_type;
ffn_weights.output.group_size = group_size;
mallocWeights();
}
......@@ -71,13 +83,11 @@ void freeWeights(LlamaDenseWeight<T>& weights)
{
cudaFree(weights.kernel);
cudaFree(weights.bias);
cudaFree(weights.scales);
cudaFree(weights.zeros);
cudaFree(weights.scales_and_zeros);
weights.kernel = nullptr;
weights.bias = nullptr;
weights.scales = nullptr;
weights.zeros = nullptr;
weights.scales_and_zeros = nullptr;
}
template<typename T>
......@@ -93,9 +103,10 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
FT_CHECK(weights.input_dims % factor == 0);
deviceMalloc((float**)&weights.kernel, weights.input_dims / factor * weights.output_dims);
deviceMalloc((T**)&weights.scales, weights.output_dims);
deviceMalloc((T**)&weights.zeros, weights.output_dims);
deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
}
}
......@@ -195,40 +206,15 @@ void loadWeights(LlamaDenseWeight<T>& w,
}
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
FT_CHECK(dim0 % factor == 0);
const auto f32_type = FtCudaDataType::FP32;
std::vector<ConcateSlice> weight_slices{};
std::vector<ConcateSlice> bias_slices{};
if (enable_slice) {
if (slice_dim == 1) {
size_t start = 0;
ConcateSlice slice0{.slices = {{0, dim0}}};
ConcateSlice slice1{.slices = {{}}};
for (auto len : slice_shape) {
size_t stride = len / tensor_para_size;
slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
weight_slices = {slice0, slice1};
ConcateSlice bias_slice0{.slices = {{0, 1}}};
bias_slices = {bias_slice0, slice1};
}
else {
size_t start = 0;
ConcateSlice slice0{.slices = {}};
ConcateSlice slice1{.slices = {{0, dim1}}};
for (auto len : slice_shape) {
size_t stride = len / factor / tensor_para_size;
slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
weight_slices = {slice0, slice1};
}
}
loadWeightFromBin((float*)w.kernel, {dim0 / factor, dim1}, prefix + ".qweight", f32_type, weight_slices);
loadWeightFromBin((T*)w.scales, {1, dim1}, prefix + ".scales", type, bias_slices);
loadWeightFromBin((T*)w.zeros, {1, dim1}, prefix + ".zeros", type, bias_slices);
FT_CHECK(dim1 % factor == 0);
std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
}
}
......@@ -241,8 +227,14 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_);
turbomind::mallocWeights(self_attn_weights.output, attn_bias_);
if (weight_type_ == WeightType::kINT4) {
turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false);
}
else {
turbomind::mallocWeights(ffn_weights.gating, false);
turbomind::mallocWeights(ffn_weights.intermediate, false);
}
turbomind::mallocWeights(ffn_weights.output, false);
}
......@@ -254,8 +246,15 @@ LlamaDecoderLayerWeight<T>::~LlamaDecoderLayerWeight()
freeWeights(self_attn_weights.qkv);
freeWeights(self_attn_weights.output);
if (weight_type_ == WeightType::kINT4) {
freeWeights(ffn_weights.fused_gating_intermediate);
}
else {
freeWeights(ffn_weights.gating);
freeWeights(ffn_weights.intermediate);
}
freeWeights(ffn_weights.output);
}
......@@ -276,9 +275,22 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
tensor_para_size_,
1,
{head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_});
loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0);
if (weight_type_ == WeightType::kINT4) {
loadWeights(ffn_weights.fused_gating_intermediate,
dir_path + ".feed_forward.w13",
tensor_para_rank_,
type,
tensor_para_size_,
1);
}
else {
loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1);
loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
loadWeights(
ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
}
loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0);
// load kv_cache quant scale
......
......@@ -33,6 +33,7 @@ public:
size_t size_per_head,
size_t inter_size,
WeightType weight_type,
int group_size,
bool attn_bias,
size_t tensor_para_size,
size_t tensor_para_rank);
......
......@@ -48,18 +48,18 @@ inline size_t getBitSize(WeightType type)
case WeightType::kINT4:
return 4;
}
return 0;
}
template<typename T>
struct LlamaDenseWeight {
size_t input_dims;
size_t output_dims;
void* kernel;
WeightType type;
T* bias;
T* scales;
T* zeros;
T* scales_and_zeros;
int group_size;
};
template<typename T>
......@@ -74,6 +74,7 @@ struct LlamaFfnWeight {
LlamaDenseWeight<T> gating;
LlamaDenseWeight<T> intermediate;
LlamaDenseWeight<T> output;
LlamaDenseWeight<T> fused_gating_intermediate;
};
} // namespace turbomind
......@@ -85,12 +85,16 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
T* ffn_output_data = output_tensors->at("ffn_output").getPtr<T>();
PUSH_RANGE("ffn");
// TODO: fuse the two GEMMs with activation
linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
if (weights->fused_gating_intermediate.kernel) {
linear_.forward(
gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
}
else {
linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
activation(num_token);
}
linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
POP_RANGE;
......