Unverified Commit c3290cad authored by Li Zhang, committed by GitHub

[Feature] Blazing fast W4A16 inference (#202)

* add w4a16

* fix `deploy.py`

* add doc

* add w4a16 kernels

* fuse w1/w3 & bugfixes

* fix typo

* python

* guard sm75/80 features

* add missing header

* refactor

* qkvo bias

* update cost model

* fix lint

* update `deploy.py`
parent d3dbe179
@@ -304,6 +304,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:llama_fmha>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
......
@@ -34,6 +34,19 @@ bash workspace/service_docker_up.sh
</details>
<details open>
<summary><b>7B with INT4 weight-only quantization</b></summary>
```shell
python3 -m lmdeploy.serve.turbomind.deploy llama2 /path/to/llama-2-7b-chat-hf \
--model_format awq \
--group_size 128 \
--quant_path /path/to/awq-quant-weight.pt
bash workspace/service_docker_up.sh
```
</details>
## Serving [LLaMA](https://github.com/facebookresearch/llama)
Weights for the LLaMA models can be obtained by filling out [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
......
@@ -5,6 +5,7 @@ import os
import os.path as osp
import re
import shutil
import sys
from pathlib import Path
import fire
@@ -12,9 +13,10 @@ import safetensors
import torch
from sentencepiece import SentencePieceProcessor
import lmdeploy
from lmdeploy.model import MODELS
supported_formats = ['llama', 'hf', 'awq']
def get_package_root_path():
@@ -107,7 +109,9 @@ def export(model_name: str,
tokenizer_path: str,
out_dir: str,
tp: int,
size_per_head: int = 128,
group_size: int = 0,
weight_type: str = 'fp16'):
"""Export deploying information to a config file.
Args:
@@ -127,9 +131,10 @@ def export(model_name: str,
print(name, param.shape)
if param.dtype in [torch.float, torch.bfloat16]:
param = param.half()
param.contiguous().cpu().numpy().tofile(osp.join(out_dir, name))
attn_bias = False
inter_size = 0
# reverse the splitting axes since the weights are transposed above
for param_name, param_data in model_params.items():
@@ -141,10 +146,14 @@ def export(model_name: str,
if key == 'w_qkv' and ext == 'bias':
attn_bias = True
copy = False
if key in ['w1', 'w3', 'w13']:
split_dim = -1
# TODO: move parameter extraction outside of the loop
if key == 'w1':
inter_size = max(inter_size, param_data.shape[-1])
elif key == 'w13':
inter_size = max(inter_size, param_data.shape[-1] // 2)
elif key == 'w_qkv':
split_dim = -2
elif key in ['w2', 'wo']:
@@ -170,6 +179,8 @@ def export(model_name: str,
else:
save_bin(param_data, param_name)
assert inter_size > 0
# export config and save it to {out_dir}/config.ini
model = MODELS.get(model_name)()
vocab_size, bos_id, eos_id = tokenizer_info(tokenizer_path)
@@ -188,7 +199,8 @@ def export(model_name: str,
attn_bias=int(attn_bias),
start_id=bos_id,
end_id=eos_id,
weight_type=weight_type,
group_size=group_size,
# parameters for turbomind
max_batch_size=32,
max_context_token_num=4,
@@ -329,7 +341,7 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
def permute(x: torch.Tensor):
SIZE_PER_HEAD = 128
if x.shape[-1] > 1:
dim = x.shape[-1]
n_heads = dim // SIZE_PER_HEAD
return x.view(-1, n_heads, 2,
@@ -491,6 +503,228 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
tokenizer_path, triton_models_path, tp)
def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int, quant_path: str,
group_size: int):
"""Deploy a model with huggingface transformers' format.
Args:
model_name (str): the name of the to-be-deployed model
model_path (str): the path of the directory where the model weight
files are
tokenizer_path (str): the path of the tokenizer model path
triton_models_path (str): the path of the exported triton models
tp (int): the number of tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
"""
if tokenizer_path is None:
tokenizer_path = osp.join(model_path, 'tokenizer.model')
if osp.exists(tokenizer_path):
shutil.copy(tokenizer_path,
osp.join(triton_models_path, 'tokenizer/tokenizer.model'))
for _file in os.listdir(model_path):
if _file.endswith('.json') or _file.endswith('.py'):
json_path = osp.join(model_path, _file)
shutil.copy(json_path,
osp.join(triton_models_path, 'tokenizer', _file))
with get_package_root_path() as root_path:
shutil.copy(osp.join(root_path, 'turbomind/tokenizer.py'),
osp.join(triton_models_path, 'tokenizer'))
else:
print(f'tokenizer model {tokenizer_path} does not exist')
exit(-1)
# read model arguments from config.json
try:
params_path = osp.join(model_path, 'config.json')
with open(params_path) as f:
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
except Exception as e:
print(f'get "num_hidden_layers" and "rms_norm_eps" from '
f'{params_path} failed: {e}')
return False
# convert weights from hf to turbomind
if quant_path is None:
_files = [
osp.join(model_path, file) for file in os.listdir(model_path)
if file.endswith('.bin')
]
_files = sorted(_files)
else:
_files = [quant_path]
model_params = {}
_params = {}
for _file in _files:
_tmp = torch.load(_file, map_location='cpu')
_params.update(_tmp)
def get_tensor(name):
"""return tensor according its name."""
return _params[name].cuda().contiguous()
# import _turbomind as _tm
# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm # noqa: E402
def transpose_qk(src: torch.Tensor):
assert src.is_contiguous()
dst = torch.zeros_like(src)
_tm.transpose_qk_s4_k_m8(src, dst,
src.size(-1) * 8, src.size(0), group_size)
return dst
def fuse_w1_w3(w1_qw: torch.Tensor, w1_qz: torch.Tensor,
w1_s: torch.Tensor, w3_qw: torch.Tensor,
w3_qz: torch.Tensor, w3_s: torch.Tensor):
def fuse(a: torch.Tensor, b: torch.Tensor):
ab = torch.cat((a, b)).contiguous()
_ab = torch.zeros_like(ab)
_tm.fuse_w1_w3_s4_k_m8(ab, _ab, a.size(-1) * 8, a.size(0))
return _ab.view(a.size(0), -1)
w13_qw = fuse(w1_qw, w3_qw)
w13_qz = fuse(w1_qz, w3_qz)
w13_s = torch.cat((w1_s, w3_s)).view(2, w1_s.size(0), -1)
w13_s = w13_s.permute(1, 2, 0).contiguous().view(w1_s.size(0), -1)
return w13_qw, w13_qz, w13_s
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
assert qw.is_contiguous()
assert qz.is_contiguous()
assert s.is_contiguous()
_qw = torch.zeros_like(qw)
_sz = torch.zeros_like(s, dtype=torch.int32)
_ws = torch.zeros_like(s)
_tm.convert_s4_k_m8(_qw, _sz, _ws, qw, s, qz,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz
attn_bias = False
for i in range(num_layer):
print(i)
# attention weights
q_qw = get_tensor(f'model.layers.{i}.self_attn.q_proj.qweight')
k_qw = get_tensor(f'model.layers.{i}.self_attn.k_proj.qweight')
v_qw = get_tensor(f'model.layers.{i}.self_attn.v_proj.qweight')
o_qw = get_tensor(f'model.layers.{i}.self_attn.o_proj.qweight')
q_qz = get_tensor(f'model.layers.{i}.self_attn.q_proj.qzeros')
k_qz = get_tensor(f'model.layers.{i}.self_attn.k_proj.qzeros')
v_qz = get_tensor(f'model.layers.{i}.self_attn.v_proj.qzeros')
o_qz = get_tensor(f'model.layers.{i}.self_attn.o_proj.qzeros')
q_s = get_tensor(f'model.layers.{i}.self_attn.q_proj.scales')
k_s = get_tensor(f'model.layers.{i}.self_attn.k_proj.scales')
v_s = get_tensor(f'model.layers.{i}.self_attn.v_proj.scales')
o_s = get_tensor(f'model.layers.{i}.self_attn.o_proj.scales')
try:
q_b = get_tensor(f'model.layers.{i}.self_attn.q_proj.bias')
k_b = get_tensor(f'model.layers.{i}.self_attn.k_proj.bias')
v_b = get_tensor(f'model.layers.{i}.self_attn.v_proj.bias')
o_b = get_tensor(f'model.layers.{i}.self_attn.o_proj.bias')
attn_bias = True
except: # noqa: E722
pass
q_qw = transpose_qk(q_qw)
k_qw = transpose_qk(k_qw)
q_qz = transpose_qk(q_qz)
k_qz = transpose_qk(k_qz)
q_s = permute(q_s)
k_s = permute(k_s)
qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2)
qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
model_params[f'layers.{i}.attention.w_qkv.qweight'] = qkv_qw
model_params[f'layers.{i}.attention.w_qkv.scales_zeros'] = qkv_sz
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
model_params[f'layers.{i}.attention.wo.qweight'] = o_qw
model_params[f'layers.{i}.attention.wo.scales_zeros'] = o_sz
if attn_bias:
q_b = permute(q_b)
k_b = permute(k_b)
qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b
model_params[f'layers.{i}.attention.wo.bias'] = o_b
# ffn weights
w1_qw = get_tensor(f'model.layers.{i}.mlp.gate_proj.qweight')
w2_qw = get_tensor(f'model.layers.{i}.mlp.down_proj.qweight')
w3_qw = get_tensor(f'model.layers.{i}.mlp.up_proj.qweight')
w1_qz = get_tensor(f'model.layers.{i}.mlp.gate_proj.qzeros')
w2_qz = get_tensor(f'model.layers.{i}.mlp.down_proj.qzeros')
w3_qz = get_tensor(f'model.layers.{i}.mlp.up_proj.qzeros')
w1_s = get_tensor(f'model.layers.{i}.mlp.gate_proj.scales')
w2_s = get_tensor(f'model.layers.{i}.mlp.down_proj.scales')
w3_s = get_tensor(f'model.layers.{i}.mlp.up_proj.scales')
w13_qw, w13_qz, w13_s = fuse_w1_w3(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
w3_s)
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
model_params[f'layers.{i}.feed_forward.w13.qweight'] = w13_qw
model_params[f'layers.{i}.feed_forward.w13.scales_zeros'] = w13_sz
model_params[f'layers.{i}.feed_forward.w2.qweight'] = w2_qw
model_params[f'layers.{i}.feed_forward.w2.scales_zeros'] = w2_sz
# norm weights
attn_norm = get_tensor(f'model.layers.{i}.input_layernorm.weight')
ffn_norm = get_tensor(
f'model.layers.{i}.post_attention_layernorm.weight')
model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm
other = [('tok_embeddings.weight', 'model.embed_tokens.weight'),
('norm.weight', 'model.norm.weight'),
('output.weight', 'lm_head.weight')]
for ft, hf in other:
model_params[ft] = get_tensor(hf)
return export(model_name,
num_layer,
norm_eps,
kv_head_num,
model_params,
tokenizer_path,
triton_models_path,
tp,
weight_type='int4',
group_size=group_size)
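For reference, a minimal usage sketch of the new AWQ path. This is not part of the patch; the paths are placeholders and the positional `(model_name, model_path, ...)` signature is assumed from the CLI example added to the README above.

```python
# Hedged sketch: drive the new AWQ branch programmatically instead of via the CLI.
# Paths are placeholders; `main` is the fire entry point defined further below.
from lmdeploy.serve.turbomind.deploy import main

main('llama2',                                  # model_name registered in MODELS
     '/path/to/llama-2-7b-chat-hf',             # hf model dir (config.json, tokenizer)
     model_format='awq',                        # selects deploy_awq above
     tp=1,
     quant_path='/path/to/awq-quant-weight.pt',
     group_size=128)                            # must match the AWQ quantization group size
```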
def pack_model_repository(workspace_path: str):
"""package the model repository.
@@ -521,7 +755,9 @@ def main(model_name: str,
model_format: str = 'hf',
tokenizer_path: str = None,
dst_path: str = './workspace',
tp: int = 1,
quant_path: str = None,
group_size: int = 0):
"""deploy llama family models via turbomind. """deploy llama family models via turbomind.
Args: Args:
...@@ -533,6 +769,9 @@ def main(model_name: str, ...@@ -533,6 +769,9 @@ def main(model_name: str,
tokenizer_path (str): the path of tokenizer model tokenizer_path (str): the path of tokenizer model
dst_path (str): the destination path that saves outputs dst_path (str): the destination path that saves outputs
tp (int): the number of GPUs used for tensor parallelism tp (int): the number of GPUs used for tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
""" """
assert model_name in MODELS.module_dict.keys(), \ assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \ f"'{model_name}' is not supported. " \
...@@ -558,9 +797,12 @@ def main(model_name: str, ...@@ -558,9 +797,12 @@ def main(model_name: str,
if model_format == 'llama': if model_format == 'llama':
res = deploy_llama(model_name, model_path, tokenizer_path, res = deploy_llama(model_name, model_path, tokenizer_path,
triton_models_path, tp) triton_models_path, tp)
else: elif model_format == 'hf':
res = deploy_hf(model_name, model_path, tokenizer_path, res = deploy_hf(model_name, model_path, tokenizer_path,
triton_models_path, tp) triton_models_path, tp)
elif model_format == 'awq':
res = deploy_awq(model_name, model_path, tokenizer_path,
triton_models_path, tp, quant_path, group_size)
# update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
......
@@ -69,3 +69,5 @@ set_property(TARGET sampling_penalty_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOL
add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_subdirectory(gemm_s_f16)
# Copyright (c) OpenMMLab. All rights reserved.
add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu)
target_compile_options(gemm_s4_f16 PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
set_property(TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET gemm_s4_f16 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cassert>
#include <cstdint>
#include <cuda_fp16.h>
#include <type_traits>
namespace turbomind {
#ifndef TURBOMIND_S4_DEQUANT_USE_FMA
#define TURBOMIND_S4_DEQUANT_USE_FMA 0
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#define TURBOMIND_ARCH_SM75 1
#else
#define TURBOMIND_ARCH_SM75 0
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define TURBOMIND_ARCH_SM80 1
#else
#define TURBOMIND_ARCH_SM80 0
#endif
constexpr int WARP_SIZE = 32;
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
#define PRAGMA_UNROLL _Pragma("unroll")
#define PRAGMA_NO_UNROLL _Pragma("unroll 1")
#else
#define PRAGMA_UNROLL #pragma unroll
#define PRAGMA_NO_UNROLL #pragma unroll 1
#endif
#else
#define PRAGMA_UNROLL
#define PRAGMA_NO_UNROLL
#endif
// Modified from NVIDIA FasterTransformer:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// Modified from llm-awq https://github.com/mit-han-lab/llm-awq/blob/main/awq/kernels/csrc/quantization/dequantize.cuh
__inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}
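As a sanity check on the magic-number arithmetic described in the comments above, here is a small Python sketch (NumPy on fp16 bit patterns, not the kernel itself) that verifies the two conversion steps.

```python
import numpy as np

# Hedged numeric check of the trick above, not the CUDA code: OR-ing a 4-bit value q
# into the low mantissa bits of 0x6400 (fp16 for 1024.0) encodes exactly 1024 + q, so
# subtracting 1024.0 recovers q. For the nibble sitting 4 bits higher, the encoded
# value is 1024 + 16*q; multiplying by 1/16 and subtracting 64 (the ONE_SIXTEENTH /
# NEG_64 step) recovers q as well.
for q in range(16):
    lo = np.array([0x6400 | q], dtype=np.uint16).view(np.float16)[0]
    hi = np.array([0x6400 | (q << 4)], dtype=np.uint16).view(np.float16)[0]
    assert float(lo) - 1024.0 == q
    assert float(hi) / 16.0 - 64.0 == q
```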
__inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const& i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOT_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t MAGIC_NUM_0 = 0x64006400; // `1024`
static constexpr uint32_t MAGIC_NUM_1 = 0x54005400; // `64`
static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4; // `64` >> 4
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
if (0) { // 1024 & 64
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
}
else { // 64 only, trade 4 hfma2 with 2 shifts
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
h[0] <<= 4;
h[2] <<= 4;
// we don't need to subtract the magic nums because zeros will go through the same dequant function
// and carry the same magic constant, the magic num will be canceled out after subtracting zeros
}
return result;
}
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
}
__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(smem_int_ptr));
#else
assert(TURBOMIND_ARCH_SM75);
#endif
}
__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
#if TURBOMIND_ARCH_SM75
asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(d0), "=r"(d1) : "r"(smem_int_ptr));
#else
assert(TURBOMIND_ARCH_SM75);
#endif
}
__inline__ __device__ void wait_flag(int* lock, int status, int thread_id)
{
int state = 0;
while (__syncthreads_and(state != status)) {
if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
__syncthreads(); // memory fence
}
__inline__ __device__ void release_flag(int* lock, int status, int thread_id)
{
__syncthreads(); // memory fence
if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
uint s, z;
(half2&)z = __halves2half2(q.x, q.x);
(half2&)s = __halves2half2(q.y, q.y);
auto& t = (const uint&)x;
uint u, v;
if (TURBOMIND_S4_DEQUANT_USE_FMA) {
asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
}
else {
asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
}
return (half2&)v;
}
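Both branches of `apply_Q` compute the same dequantization `(x - z) * s`; a hedged scalar reference (NumPy fp16, not the half2 PTX) showing the sub+mul form next to the FMA form that uses the precomputed `-z * s` prepared on the host side (see `merge_Q` in format.cu below):

```python
import numpy as np

# Hedged reference, not the kernel: apply_Q dequantizes as (x - z) * s. The FMA path
# consumes the pair (-z * s, s), so the whole step becomes one multiply-add. In fp16
# the two forms can differ by rounding, which is why the FMA path is disabled by
# default (TURBOMIND_S4_DEQUANT_USE_FMA == 0).
def apply_q_ref(x, z, s, use_fma=False):
    x, z, s = (np.float16(v) for v in (x, z, s))
    if use_fma:
        return np.float16(x * s + np.float16(-z * s))  # fma.rn.f16x2 path
    return np.float16((x - z) * s)                     # sub.ftz + mul.ftz path

assert apply_q_ref(3.0, 8.0, 0.5) == apply_q_ref(3.0, 8.0, 0.5, use_fma=True)
```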
template<typename T, int N>
struct Array {
T a[N];
__device__ __host__ constexpr T& operator[](int i) noexcept
{
return a[i];
}
__device__ __host__ constexpr const T& operator[](int i) const noexcept
{
return a[i];
}
};
template<int... Ns>
struct Shape {
static constexpr Array<int, sizeof...(Ns)> data_{Ns...};
constexpr Shape() = default;
Shape(std::integral_constant<int, Ns>...){};
template<int index>
constexpr auto get() const noexcept
{
return std::integral_constant<int, data_[index]>{};
}
constexpr auto m() const noexcept
{
return get<0>();
}
constexpr auto n() const noexcept
{
return get<1>();
}
constexpr auto k() const noexcept
{
return get<2>();
}
constexpr int c() const noexcept
{
return get<0>();
}
constexpr int s() const noexcept
{
return get<1>();
}
constexpr int count() const noexcept
{
return (Ns * ...);
}
};
template<int... Ns>
Shape(std::integral_constant<int, Ns>...) -> Shape<Ns...>;
template<int... Ns>
inline constexpr Shape<Ns...> shape_c{};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "common.h"
#include <cstdint>
namespace turbomind {
template<typename T>
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
{
#if TURBOMIND_ARCH_SM80
constexpr int cp_size = sizeof(T);
static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global.L2::256B [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
template<typename T>
__inline__ __device__ void cp_async_cg_B(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
{
#if TURBOMIND_ARCH_SM80
constexpr int cp_size = sizeof(T);
static_assert(cp_size == 16, "cp.async.cg requires cp_size == 16");
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
template<typename T>
__inline__ __device__ void cp_async_ca(uint32_t smem_int_ptr, const T* __restrict__ src, bool mask)
{
#if TURBOMIND_ARCH_SM80
constexpr int cp_size = sizeof(T);
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
struct IteratorA {
static constexpr int SLICE_K = CTA_K / SLICES;
using AccessType = uint4;
static constexpr int kAccessSize = sizeof(AccessType);
static_assert(CTA_M % 32 == 0 && CTA_K % 32 == 0, "A is pre-formatted as 32x32 tiles");
// A is [K/32, M/32, WARP_SIZE] uint4
static constexpr int kShapeM = CTA_M;
static constexpr int kShapeK = SLICE_K / 32;
// thread access shape
static constexpr int kAccessM = 1;
static constexpr int kAccessK = 1;
// warp thread arrangement
static constexpr int kWarpThreadC = 32;
static constexpr int kWarpThreadS = 1;
// warp shape per access
static constexpr int kWarpAccessM = kWarpThreadC * kAccessM; // 32
static constexpr int kWarpAccessK = kWarpThreadS * kAccessK; // 1
// warp access iterations
static constexpr int kWarpIterM = kShapeM / kWarpAccessM;
static constexpr int kWarpIterK = kShapeK / kWarpAccessK;
// warp arrangement
static constexpr int kWarpM = kWarpIterM >= WARPS ? WARPS : kWarpIterM;
static constexpr int kWarpK = WARPS > kWarpIterM ? (WARPS / kWarpM) : 1;
// iterations
static constexpr int kIterM = kWarpIterM / kWarpM;
static constexpr int kIterK = kWarpIterK / kWarpK;
static constexpr int kIterCount = kIterM * kIterK;
static_assert(kIterCount > 0);
// warp footprint
static constexpr int kWarpFootprintM = kWarpAccessM * kIterM;
static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
static constexpr int kSizePerStage = kShapeK * kShapeM;
static constexpr int kSmemByteSize = kAccessSize * STAGES * kSizePerStage;
const uint* src_;
AccessType* smem_;
uint32_t smem_int_ptr_;
const int m_;
const int k_;
const int warp_id_;
const int lane_id_;
int src_offset_;
int dst_offset_;
int src_step_m_;
int src_step_k_;
int src_step_s_;
int dst_step_m_;
int dst_step_k_;
int dst_step_s_;
int iter_m_{0};
IteratorA() = default;
__device__ IteratorA(const uint* src, void* smem, int m, int k, int cta_m, int cta_k, int warp_id, int lane_id):
src_(src),
smem_((AccessType*)smem),
smem_int_ptr_(cast_smem_ptr_to_uint(smem)),
m_(m),
k_(k),
warp_id_(warp_id),
lane_id_(lane_id)
{
const int warp_offset_m = warp_id_ % kWarpM;
const int warp_offset_k = warp_id_ / kWarpM;
const int warp_thread_offset_m = lane_id_ % kWarpThreadC;
const int warp_thread_offset_k = lane_id_ / kWarpThreadC;
const int cta_thread_offset_m = kWarpFootprintM * warp_offset_m + warp_thread_offset_m * kAccessM;
const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
const int src_offset_m = cta_thread_offset_m + cta_m;
const int src_offset_k = cta_thread_offset_k + cta_k / 32;
src_offset_ = src_offset_k * m_ + src_offset_m;
src_step_m_ = kWarpAccessM;
src_step_k_ = kWarpAccessK * m_ - kIterM * kWarpAccessM;
src_step_s_ = CTA_K / 32 * m_ - kIterK * kWarpAccessK * m_;
const int dst_offset_m = cta_thread_offset_m;
const int dst_offset_k = cta_thread_offset_k;
dst_offset_ = dst_offset_k * kShapeM + dst_offset_m;
dst_step_m_ = kWarpAccessM;
dst_step_k_ = kWarpAccessK * kShapeM - kIterM * kWarpAccessM;
dst_step_s_ = SLICE_K / 32 * kShapeM - kIterK * kWarpAccessK * kShapeM;
dst_offset_ *= kAccessSize;
dst_step_m_ *= kAccessSize;
dst_step_k_ *= kAccessSize;
dst_step_s_ *= kAccessSize;
}
__device__ void prefetch_stage(bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
prefetch(mask);
++(*this);
}
next_stage();
}
__device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
prefetch(mask);
++(*this);
}
}
}
__device__ IteratorA& operator++()
{
src_offset_ += src_step_m_;
dst_offset_ += dst_step_m_;
++iter_m_;
if (iter_m_ < kIterM) {
return *this;
}
iter_m_ = 0;
src_offset_ += src_step_k_;
dst_offset_ += dst_step_k_;
return *this;
}
__device__ void next_stage()
{
src_offset_ += src_step_s_;
dst_offset_ += dst_step_s_;
if (dst_offset_ >= kSmemByteSize) {
dst_offset_ -= kSmemByteSize;
}
}
__device__ void prefetch(bool mask)
{
cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
}
};
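To make the constexpr bookkeeping above concrete, a hedged sketch deriving IteratorA's constants for one hypothetical tile configuration (the numbers are illustrative, not a statement about which kernel the cost model picks):

```python
# Hedged sketch mirroring IteratorA's constexpr arithmetic for a hypothetical
# configuration: WARPS=4, CTA_M=128, CTA_K=32, STAGES=3, SLICES=1.
WARPS, CTA_M, CTA_K, STAGES, SLICES = 4, 128, 32, 3, 1
SLICE_K = CTA_K // SLICES                        # 32
kShapeM, kShapeK = CTA_M, SLICE_K // 32          # 128 x 1 (A comes pre-formatted in 32x32 tiles)
kWarpAccessM, kWarpAccessK = 32, 1               # one uint4 per lane, 32 lanes along M
kWarpIterM = kShapeM // kWarpAccessM             # 4
kWarpIterK = kShapeK // kWarpAccessK             # 1
kWarpM = min(kWarpIterM, WARPS)                  # 4 warps laid out along M
kWarpK = WARPS // kWarpM if WARPS > kWarpIterM else 1        # 1
kIterM, kIterK = kWarpIterM // kWarpM, kWarpIterK // kWarpK  # 1, 1
kSmemByteSize = 16 * STAGES * kShapeK * kShapeM  # 6144 bytes of smem for A
print(kIterM * kIterK, kSmemByteSize)            # -> 1 6144
```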
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES, int GROUP_SIZE>
struct IteratorQ {
static constexpr int SLICE_K = CTA_K / SLICES;
using AccessType = uint;
static constexpr int kAccessSize = sizeof(AccessType);
static constexpr int kAccessM = kAccessSize / sizeof(half2);
static constexpr int kAccessK = GROUP_SIZE;
// warp thread arrangement
static constexpr int kWarpThreadC = 32;
static constexpr int kWarpThreadS = 1;
// warp shape per access
static constexpr int kWarpAccessM = kWarpThreadC * kAccessM; // 32
static constexpr int kWarpAccessK = kWarpThreadS * kAccessK; // GROUP_SIZE
// warp access iterations
static constexpr int kWarpIterM = CTA_M / kWarpAccessM; // CTA_M / 32
static constexpr int kWarpIterK = SLICE_K / kWarpAccessK; // SLICE_K / GROUP_SIZE, maybe 0
// kWarpIterK == 0 => SLICE_K < kWarpAccessK => kIterK == 1
// warp arrangement
static constexpr int kWarpM = kWarpIterM >= WARPS ? WARPS : kWarpIterM;
static constexpr int kWarpK = WARPS > kWarpIterM ? WARPS / kWarpM : 1;
// iterations
static constexpr int kIterM = kWarpIterM / kWarpM;
static constexpr int kIterK = kWarpIterK >= kWarpK ? kWarpIterK / kWarpK : 1;
static constexpr int kIterCount = kIterM * kIterK;
// warp footprint
static constexpr int kWarpFootprintM = kWarpAccessM * kIterM;
static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
static constexpr int kSizePerStage = std::max(SLICE_K / GROUP_SIZE, 1) * CTA_M;
static constexpr int kSmemByteSize = sizeof(uint) * STAGES * kSizePerStage;
const half2* const src_;
half2* const smem_;
uint32_t const smem_int_ptr_;
const int m_;
const int k_;
bool is_out_of_bound_; // mask for out-of-bound warps
int src_offset_k_;
int src_offset_m_;
int src_offset_;
int src_step_m_;
int src_step_k_;
int dst_offset_;
int dst_step_m_;
int dst_step_k_;
int tmp_src_offset_;
int tmp_dst_offset_;
int iter_m_{0};
struct Storage {
half2 data[SLICES][STAGES * kSizePerStage];
};
IteratorQ() = default;
__device__ IteratorQ(const half2* src, half2* smem, int m, int k, int cta_m, int cta_k, int warp_id, int lane_id):
src_(src), smem_(smem), smem_int_ptr_(cast_smem_ptr_to_uint(smem)), m_(m), k_(k)
{
const int warp_offset_m = warp_id % kWarpM;
const int warp_offset_k = warp_id / kWarpM;
const int warp_thread_offset_m = lane_id % kWarpThreadC;
const int warp_thread_offset_k = lane_id / kWarpThreadC;
const int cta_thread_offset_m = kWarpFootprintM * warp_offset_m + warp_thread_offset_m * kAccessM;
const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
// mask out-of-bound warps
is_out_of_bound_ = cta_thread_offset_k >= SLICE_K;
src_offset_m_ = cta_thread_offset_m + cta_m;
src_offset_k_ = cta_thread_offset_k + cta_k;
src_offset_ = src_offset_k_ / GROUP_SIZE * m_ + src_offset_m_;
src_step_m_ = kWarpAccessM;
src_step_k_ = m_ - kIterM * kWarpAccessM; // valid only when SLICE_K >= GROUP_SIZE
const int dst_offset_m = cta_thread_offset_m;
const int dst_offset_k = cta_thread_offset_k;
dst_offset_ = dst_offset_k / GROUP_SIZE * CTA_M + dst_offset_m;
dst_step_m_ = kWarpAccessM;
dst_step_k_ = CTA_M - kIterM * kWarpAccessM; // valid only when SLICE_K >= GROUP_SIZE
dst_offset_ *= kAccessSize;
dst_step_m_ *= kAccessSize;
dst_step_k_ *= kAccessSize;
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
}
__device__ void prefetch_stage(bool mask)
{
if (is_out_of_bound_) {
return;
}
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
prefetch(mask);
++(*this);
}
next_stage();
}
__device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
{
if (is_out_of_bound_) {
return;
}
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
prefetch(mask);
++(*this);
}
}
}
__device__ IteratorQ& operator++()
{
++iter_m_;
src_offset_ += src_step_m_;
dst_offset_ += dst_step_m_;
if (iter_m_ < kIterM) {
return *this;
}
iter_m_ = 0;
if constexpr (SLICE_K >= GROUP_SIZE) {
src_offset_ += src_step_k_;
dst_offset_ += dst_step_k_;
}
// else advance offsets in `next_stage`
return *this;
}
__device__ void next_stage()
{
if constexpr (SLICE_K >= GROUP_SIZE) {
src_offset_ += (CTA_K / GROUP_SIZE - kIterK) * m_;
dst_offset_ += kAccessSize * (SLICE_K / GROUP_SIZE - kIterK) * CTA_M;
}
else { // SLICE_K < GROUP_SIZE, recompute `src_offset_`
src_offset_k_ += CTA_K;
src_offset_ = (src_offset_k_ / GROUP_SIZE) * m_ + src_offset_m_;
dst_offset_ += dst_step_k_;
}
if (dst_offset_ >= kSmemByteSize) {
dst_offset_ -= kSmemByteSize;
}
}
__device__ void prefetch(bool mask)
{
cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
}
};
template<int WARPS, int CTA_M, int CTA_N, int CTA_K, int STAGES, int SLICES>
struct IteratorB {
static constexpr int SLICE_K = CTA_K / SLICES;
static constexpr int kElementSize = sizeof(half);
using AccessType = uint4;
static constexpr int kAccessSize = sizeof(AccessType);
static constexpr int kShapeK = SLICE_K;
static constexpr int kShapeN = CTA_N;
static constexpr int kAccessK = kAccessSize / sizeof(half);
static_assert(kShapeK % kAccessSize == 0);
// warp thread arrangement
static constexpr int kWarpThreadC = std::max(kShapeK / kAccessK, 1);
static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
// warp shape per access
static constexpr int kWarpAccessK = kWarpThreadC * kAccessK;
static constexpr int kWarpAccessN = kWarpThreadS;
// warp access iterations
static constexpr int kWarpIterK = kShapeK / kWarpAccessK;
static constexpr int kWarpIterN = kShapeN / kWarpAccessN;
// warp arrangement
static constexpr int kWarpK = kWarpIterK >= WARPS ? WARPS : kWarpIterK;
static constexpr int kWarpN = WARPS > kWarpIterK ? WARPS / kWarpK : 1;
// iterations
static constexpr int kIterK = kWarpIterK / kWarpK;
static constexpr int kIterN = kWarpIterN >= kWarpN ? kWarpIterN / kWarpN : 1;
static constexpr int kIterCount = kIterK * kIterN;
static_assert(kIterCount > 0);
// warp footprint
static constexpr int kWarpFootprintK = kWarpAccessK * kIterK;
static constexpr int kWarpFootprintN = kWarpAccessN * kIterN;
// Eliminate bank-conflicts for 8x4 half2 tiles, watch out for misalignment
static constexpr int kSmemPadCtaK = SLICE_K + 8;
static constexpr int kSizePerTile = CTA_N * kSmemPadCtaK;
static constexpr int kSmemByteSize = kElementSize * STAGES * kSizePerTile;
const half* src_;
AccessType* const smem_; // [CTA_N, SLICE_K + 8]
const uint32_t smem_int_ptr_;
const int k_;
const int n_;
const int cta_n_;
const int warp_id_;
const int lane_id_;
const int c_;
const int s_;
int src_offset_n_;
int src_offset_;
int dst_offset_;
int src_step_k_;
int src_step_n_;
int dst_step_k_;
int dst_step_n_;
bool is_valid_n_;
int tmp_src_offset_;
int tmp_dst_offset_;
int tmp_src_offset_n_;
int iter_k_{0};
int iter_n_{0};
IteratorB() = default;
__device__ IteratorB(const half* src, void* smem, int k, int n, int cta_n, int cta_k, int warp_id, int lane_id):
src_(src),
smem_((AccessType*)smem),
smem_int_ptr_(cast_smem_ptr_to_uint(smem)),
k_(k),
n_(n),
cta_n_(cta_n),
warp_id_(warp_id),
lane_id_(lane_id),
c_(lane_id_ % kWarpThreadC),
s_(lane_id_ / kWarpThreadC)
{
const int warp_offset_k = warp_id_ % kWarpK;
const int warp_offset_n = warp_id_ / kWarpK;
const int warp_thread_offset_k = lane_id_ % kWarpThreadC;
const int warp_thread_offset_n = lane_id_ / kWarpThreadC;
const int cta_thread_offset_k = kWarpFootprintK * warp_offset_k + warp_thread_offset_k * kAccessK;
const int cta_thread_offset_n = kWarpFootprintN * warp_offset_n + warp_thread_offset_n;
const int src_offset_k = cta_thread_offset_k + cta_k;
src_offset_n_ = cta_thread_offset_n + cta_n_;
src_offset_ = src_offset_n_ * k_ + src_offset_k;
const int dst_offset_k = cta_thread_offset_k;
const int dst_offset_n = cta_thread_offset_n;
dst_offset_ = dst_offset_n * kSmemPadCtaK + dst_offset_k;
src_step_k_ = kWarpAccessK;
src_step_n_ = kWarpAccessN * k_ - kIterK * kWarpAccessK;
dst_step_k_ = kWarpAccessK;
dst_step_n_ = kWarpAccessN * kSmemPadCtaK - kIterK * kWarpAccessK;
dst_offset_ *= kElementSize;
dst_step_k_ *= kElementSize;
dst_step_n_ *= kElementSize;
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
tmp_src_offset_n_ = src_offset_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
}
__device__ void prefetch_stage(bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
prefetch(mask);
++(*this);
}
next_stage();
}
__device__ void prefetch_batch(int batch_idx, int batch_size, bool mask)
{
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
prefetch(mask);
++(*this);
}
}
}
__device__ IteratorB& operator++()
{
if (!is_valid_n_) {
return *this;
}
// move to next k
tmp_src_offset_ += src_step_k_;
tmp_dst_offset_ += dst_step_k_;
++iter_k_;
if (iter_k_ < kIterK) {
return *this;
}
// move to next n
iter_k_ = 0;
tmp_src_offset_n_ += kWarpAccessN;
tmp_src_offset_ += src_step_n_;
tmp_dst_offset_ += dst_step_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
++iter_n_;
return *this;
}
__device__ void next_stage()
{
iter_n_ = 0;
src_offset_ += CTA_K;
dst_offset_ += kElementSize * kSizePerTile;
if (dst_offset_ >= kSmemByteSize) {
dst_offset_ -= kSmemByteSize;
}
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
tmp_src_offset_n_ = src_offset_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
}
__device__ void prefetch(bool mask)
{
cp_async_cg_B(
smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
}
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "common.h"
#include <iostream>
namespace turbomind {
__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
{
uint32_t old = *address;
uint32_t assumed;
do {
assumed = old;
uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));
old = atomicCAS(address, assumed, tmp);
} while (assumed != old);
}
__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
{
return (*address >> (index * 4u)) & 0xfu;
}
template<int... Ds>
__global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
{
constexpr int N = sizeof...(Ds);
size_t count = 1;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
count *= dims[i];
}
constexpr int order[] = {Ds...};
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
int indices[N]{};
PRAGMA_UNROLL
for (int j = N - 1, ii = i; j >= 0; --j) {
indices[j] = ii % dims[j];
ii /= dims[j];
}
auto data = read_u4(src + i / 8, i % 8);
int index = 0;
PRAGMA_UNROLL
for (int j = N - 1, stride = 1; j >= 0; --j) {
index += indices[order[j]] * stride;
stride *= dims[order[j]];
}
atomic_assign_u4(dst + index / 8, index % 8, data);
}
}
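The index arithmetic in `permute_u4` amounts to transposing the N-d view of the nibble array. A hedged NumPy reference, operating on already-unpacked nibbles rather than the packed `uint32` words the kernel reads and writes:

```python
import numpy as np

# Hedged reference, not the kernel: on an array with one (already unpacked) 4-bit
# value per element, permute_u4<Ds...> is equivalent to reshape(dims) followed by
# transpose(order) and flattening; the CUDA version performs the same index
# arithmetic directly on nibbles packed 8-per-uint32.
def permute_u4_ref(nibbles, dims, order):
    a = np.asarray(nibbles, dtype=np.uint8).reshape(dims)
    return a.transpose(order).ravel()

# toy example with dims (2, 3, 4) and order (2, 0, 1)
x = np.arange(24, dtype=np.uint8) % 16
y = permute_u4_ref(x, (2, 3, 4), (2, 0, 1))
```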
void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k/8, m] layout
Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
// |warp| lane | 2x2 | a0-7 |
permute_u4<0, 3, 6, 8, 9, 1, 4, 7, 2, 5><<<512, 512, 0, st>>>(dst, src, shape);
}
void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k, m/8] layout
Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
// |warp| lane | 2x2 | a0-7 |
permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
}
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
{
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
}
}
__global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
{
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// dequant via HFMA2 has numerical stability issues
Q[i] = __halves2half2(-zeros[i] * scales[i], scales[i]);
}
else {
Q[i] = __halves2half2(zeros[i], scales[i]);
}
}
}
void convert_s4_k_m8(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st)
{
dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
reformat_s4_k_m8(A_dst, A_src, m, k, st);
}
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{
Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
// dequant transpose quant
// 0123456 -> 0123564 -> 0135642 -> 0135264
permute_u4<0, 1, 3, 5, 2, 6, 4><<<512, 512, 0, st>>>(dst, src, shape);
}
// [2, k, m/8] -> [k, m/8, 2]
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
// dequant transpose quant
// 012345 -> 012453 -> 124530 -> 124053
permute_u4<1, 2, 4, 0, 5, 3><<<512, 512, 0, st>>>(dst, src, shape);
}
__global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
{
for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
dst[i] = dequantize_s4_to_fp16x2(src[i]);
}
}
void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
{
dequantize_s4_kernel<<<512, 512, 0, st>>>(dst, src, count);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
namespace turbomind {
void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
void convert_s4_k_m8(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st = {});
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st = {});
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "gemm_s4_f16.h"
#include "gemm_s4_f16_kernel.h"
#include "metric.h"
#include <algorithm>
#include <iomanip>
#include <ios>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <vector>
namespace turbomind {
bool g_dump_kernel_info_once = false;
namespace ops {
struct Identity {
static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
{
if (n < N) {
(uint&)C[n * M + m] = (uint&)data;
}
}
};
struct SiluActivation {
static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
{
auto u = __half22float2((half2&)data);
float silu = u.x / (1.f + __expf(-u.x));
half val = __float2half_rn(silu * u.y);
if (n < N) {
C[n * (M / 2) + m / 2] = val;
}
}
};
} // namespace ops
template<typename... Ts>
struct OutputOps {
template<int index>
static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
{
std::tuple_element_t<index, std::tuple<Ts...>>::apply(data, m, n, C, M, N);
}
};
struct GemmS4F16::Impl {
using Kernels = std::vector<std::unique_ptr<IGemmKernel>>;
template<int GS, typename Op>
void Generate(std::vector<Kernels>& kernels)
{
// smem size (KB):
// sm75: 64
// sm80: 163
// sm86: 99
// sm89: 99
// sm90: 227
Kernels k;
// 256
k.emplace_back(new GemmKernel<Shape<256, 128, 32>, Shape<32, 128, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 64, 64>, Shape<64, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 64, 32>, Shape<64, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 32, 64>, Shape<64, 32, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 16, 256>, Shape<32, 16, 128>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<256, 8, 256>, Shape<32, 8, 128>, 3, GS, Op>{});
// 128
k.emplace_back(new GemmKernel<Shape<128, 128, 64>, Shape<32, 128, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 128, 32>, Shape<32, 128, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 96, 64>, Shape<32, 96, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 64, 64>, Shape<32, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 64, 32>, Shape<32, 64, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 32, 128>, Shape<32, 32, 64>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 16, 256>, Shape<32, 16, 64>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 8, 512>, Shape<32, 8, 128>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<128, 8, 512>, Shape<32, 8, 128>, 2, GS, Op>{}); // for 86/89
// 64
k.emplace_back(new GemmKernel<Shape<64, 16, 256>, Shape<32, 16, 32>, 3, GS, Op>{});
k.emplace_back(new GemmKernel<Shape<64, 8, 256>, Shape<32, 8, 32>, 3, GS, Op>{});
kernels.push_back(std::move(k));
}
void Measure(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
std::vector<Metric>& metrics,
cudaStream_t st,
std::vector<Kernels>& _kernels)
{
int gid = -1;
for (size_t i = 0; i < group_sizes_.size(); ++i) {
if (group_sizes_[i] == group_size) {
gid = i;
break;
}
}
if (gid < 0) {
throw std::runtime_error("unsupported group size");
}
const auto& kernels = _kernels[gid];
metrics = std::vector<Metric>(kernels.size());
int best = 0;
for (size_t i = 0; i < kernels.size(); ++i) {
metrics[i].id = i;
kernels[i]->GetMetric(metrics[i], m, n, k);
if (!metrics[i].feasible) {
metrics[i].time = std::numeric_limits<float>::infinity();
metrics[i].count = 1;
continue;
}
if (Compare(metrics[i], metrics[best])) {
best = i;
}
for (size_t j = 0; j < kWarmup + kMeasure; ++j) {
if (j == kWarmup) {
cudaEventRecord(ev_start_, st);
}
kernels[i]->Launch(C, A, B, Q, m, n, k, type, st);
}
cudaEventRecord(ev_end_, st);
cudaEventSynchronize(ev_end_);
float ms{};
cudaEventElapsedTime(&ms, ev_start_, ev_end_);
metrics[i].time = ms;
metrics[i].count = kMeasure;
}
metrics[best].best = 1;
// sort metrics
std::vector<int> indices(kernels.size());
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(
indices.begin(), indices.end(), [&](int i, int j) { return metrics[i].time < metrics[j].time; });
if (g_dump_kernel_info_once) {
DumpMetrics(std::cerr, metrics, indices);
g_dump_kernel_info_once = 0;
}
std::vector<Metric> tmp;
for (size_t i = 0; i < indices.size(); ++i) {
tmp.push_back(metrics[indices[i]]);
}
metrics.swap(tmp);
}
static bool Compare(const Metric& a, const Metric& b)
{
if (a.feasible != b.feasible) {
return a.feasible > b.feasible;
}
if (a.prefer != b.prefer) {
return a.prefer > b.prefer;
}
return a.grid_norm < b.grid_norm;
}
int Estimate(int m, int n, int k, Kernels& kernels)
{
int best = 0;
std::vector<Metric> metrics(kernels.size());
for (size_t i = 0; i < kernels.size(); ++i) {
metrics[i].id = i;
kernels[i]->GetMetric(metrics[i], m, n, k);
if (Compare(metrics[i], metrics[best])) {
best = i;
}
}
if (g_dump_kernel_info_once) {
std::vector<int> indices(kernels.size());
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(
indices.begin(), indices.end(), [&](int i, int j) { return Compare(metrics[i], metrics[j]); });
DumpMetrics(std::cerr, metrics, indices);
g_dump_kernel_info_once = 0;
}
return best;
}
void Run(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
int algo_id,
cudaStream_t st,
std::vector<Kernels>& kernels)
{
for (size_t i = 0; i < group_sizes_.size(); ++i) {
if (group_sizes_[i] == group_size) {
if (algo_id < 0) {
algo_id = Estimate(m, n, k, kernels[i]);
}
if (algo_id < 0) {
throw std::runtime_error("no feasible kernel found");
}
kernels[i].at(algo_id)->Launch(C, A, B, Q, m, n, k, type, st);
return;
}
}
throw std::runtime_error("unsupported group size");
}
Impl()
{
cudaEventCreate(&ev_start_);
cudaEventCreate(&ev_end_);
using Ops = OutputOps<ops::Identity, ops::SiluActivation>;
/// TODO: add more group sizes
Generate<128, Ops>(kernels_);
group_sizes_.push_back(128);
}
~Impl()
{
cudaEventDestroy(ev_end_);
cudaEventDestroy(ev_start_);
}
std::vector<Kernels> kernels_;
std::vector<int> group_sizes_;
static constexpr int kWarmup = 10;
static constexpr int kMeasure = 100;
cudaEvent_t ev_start_{};
cudaEvent_t ev_end_{};
};
GemmS4F16::GemmS4F16(): impl_(std::make_unique<Impl>()) {}
GemmS4F16::~GemmS4F16() = default;
void GemmS4F16::Measure(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
std::vector<Metric>& metrics,
cudaStream_t st)
{
impl_->Measure(C, A, B, Q, m, n, k, group_size, type, metrics, st, impl_->kernels_);
}
void GemmS4F16::Run(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
int algo_id,
cudaStream_t st)
{
impl_->Run(C, A, B, Q, m, n, k, group_size, type, algo_id, st, impl_->kernels_);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "metric.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <memory>
#include <vector>
namespace turbomind {
extern bool g_dump_kernel_info_once;
class GemmS4F16 {
public:
GemmS4F16();
~GemmS4F16();
enum Type
{
kGemm,
kFusedSiluFfn
};
void Measure(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
std::vector<Metric>& metrics,
cudaStream_t st);
void Run(half* C,
const uint* A,
const half* B,
const half2* Q,
int m,
int n,
int k,
int group_size,
Type type,
int algo_id,
cudaStream_t st);
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "gemm_template.h"
#include "metric.h"
#include <iostream>
#include <memory>
#include <sstream>
namespace turbomind {
struct IGemmKernel {
virtual ~IGemmKernel() = default;
virtual void GetMetric(Metric& metric, int m, int n, int k) = 0;
virtual void Launch(half* C,
const uint* A,
const half* B,
const half2* Q,
int M,
int N,
int K,
int output_op_idx,
cudaStream_t) = 0;
virtual void Dump(std::ostream& os) = 0;
};
template<typename CtaShape, typename WarpShape, int Stages, int GroupSize, typename OutputOps>
struct GemmKernel: public IGemmKernel {
static constexpr CtaShape cta_shape{};
static constexpr WarpShape warp_shape{};
using GemmType = Gemm<cta_shape.m(),
cta_shape.n(),
cta_shape.k(),
warp_shape.m(),
warp_shape.n(),
warp_shape.k(),
Stages,
GroupSize,
OutputOps>;
decltype(&gemm_s4_f16_nn<GemmType>) kernel_func_;
std::shared_ptr<cudaDeviceProp> props_;
int max_active_ctas_{};
static constexpr int kSlices = GemmType::SLICES;
static constexpr int kSmemSizeA = GemmType::IteratorA::kSmemByteSize * kSlices;
static constexpr int kSmemSizeB = GemmType::IteratorB::kSmemByteSize * kSlices;
static constexpr int kSmemSizeC = sizeof(float) * cta_shape.m() * cta_shape.n();
static constexpr int kSmemByteSize = std::max(kSmemSizeA + kSmemSizeB, kSmemSizeC);
// static shared memory size of Q
static constexpr int kSmemSizeQ = sizeof(typename GemmType::IteratorQ::Storage);
explicit GemmKernel(std::shared_ptr<cudaDeviceProp> props = {}): props_(std::move(props))
{
if (!props_) {
props_ = std::make_shared<cudaDeviceProp>();
int device_id = -1;
cudaGetDevice(&device_id);
cudaGetDeviceProperties(props_.get(), device_id);
}
kernel_func_ = gemm_s4_f16_nn<GemmType>;
cudaFuncSetAttribute(kernel_func_, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_ctas_, kernel_func_, GemmType::kWarpCount * WARP_SIZE, kSmemByteSize);
};
bool is_feasible(int m, int n, int k)
{
return m % cta_shape.m() == 0 && k % cta_shape.k() == 0;
}
void GetMetric(Metric& metric, int m, int n, int k) override
{
metric.cta_shape = {cta_shape.m(), cta_shape.n(), cta_shape.k()};
metric.warp_shape = {warp_shape.m(), warp_shape.n(), warp_shape.k()};
metric.warps = GemmType::kWarpCount;
metric.stages = Stages;
metric.smem = (kSmemByteSize + kSmemSizeQ) / 1024.f;
metric.feasible = is_feasible(m, n, k) && max_active_ctas_ > 0;
metric.prefer = cta_shape.m() != 64 || m <= k;
if (!metric.feasible) {
return;
}
int grid_size = ((m + cta_shape.m() - 1) / cta_shape.m()) * ((n + cta_shape.n() - 1) / cta_shape.n());
metric.grid_size = grid_size;
metric.max_active_ctas = max_active_ctas_;
metric.active_ctas =
std::min(max_active_ctas_, (grid_size + props_->multiProcessorCount - 1) / props_->multiProcessorCount);
metric.waves = (float)grid_size / (props_->multiProcessorCount * metric.active_ctas);
metric.occupancy = (metric.active_ctas * GemmType::kWarpCount)
/ (float)(props_->maxThreadsPerMultiProcessor / props_->warpSize);
metric.cta_cnt_m = (m + cta_shape.m() - 1) / cta_shape.m();
metric.cta_cnt_n = (n + cta_shape.n() - 1) / cta_shape.n();
metric.cta_iter_k = (k + cta_shape.k() - 1) / cta_shape.k();
metric.tile_efficiency = (float)n / (metric.cta_cnt_n * cta_shape.n());
metric.wave_efficiency = metric.waves / std::ceil(metric.waves);
const int m_pad = (m + cta_shape.m() - 1) / cta_shape.m() * cta_shape.m();
const int n_pad = (n + cta_shape.n() - 1) / cta_shape.n() * cta_shape.n();
metric.grid_a0 = 0.25f * m * n_pad / cta_shape.n(); // Ta0 * M * [N / ctaN]
metric.grid_b0 = 1.00f * n * m_pad / cta_shape.m(); // Tb0 * N * [M / ctaM]
metric.grid_a1 = 0.65f * m_pad * n_pad / warp_shape.n(); // Ta1 * [M] * [N] / warpN
metric.grid_b1 = 0.25f * m_pad * n_pad / warp_shape.m(); // Tb1 * [M] * [N] / warpM
metric.grid_mm = 1.00f * m_pad * n_pad / 64; // Tm * [M] * [N]
metric.grid_sum = metric.grid_a0 + metric.grid_b0 + metric.grid_a1 + metric.grid_b1 + metric.grid_mm;
metric.cta_sum = metric.grid_sum / grid_size;
metric.waves1 = (float)grid_size / (props_->multiProcessorCount * metric.active_ctas);
metric.cta_wave = std::ceil(metric.waves1) * metric.active_ctas;
metric.grid_norm = metric.cta_wave * metric.cta_sum;
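// (interpretation) grid_sum aggregates the weighted traffic/MMA terms above, and
// grid_norm rescales the per-CTA average to whole waves of resident CTAs so that
// differently shaped tile configurations can be compared on a common cost estimate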
}
void Launch(
half* C, const uint* A, const half* B, const half2* Q, int M, int N, int K, int output_op_idx, cudaStream_t st)
override
{
constexpr int block_size = GemmType::kWarpCount * WARP_SIZE;
dim3 grid_size((M + cta_shape.m() - 1) / cta_shape.m(), (N + cta_shape.n() - 1) / cta_shape.n());
kernel_func_<<<grid_size, block_size, kSmemByteSize, st>>>(C, A, B, Q, M, N, K, output_op_idx);
}
void Dump(std::ostream& os) override
{
{
os << "[Gemm] CTA shape: " << cta_shape.m() << "x" << cta_shape.n() << "x" << cta_shape.k() << std::endl;
os << "[Gemm] warp shape: " << warp_shape.m() << "x" << warp_shape.n() << "x" << warp_shape.k()
<< std::endl;
os << "[Gemm] warp count: " << GemmType::kWarpCountM << "x" << GemmType::kWarpCountN << "x"
<< GemmType::kWarpCountK << " (" << GemmType::kWarpCount << ")" << std::endl;
os << std::endl;
}
{
using Iter = typename GemmType::IteratorA;
os << "[A] shape: " << Iter::kShapeM << " " << Iter::kShapeK << std::endl;
os << "[A] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
os << "[A] warp shape per access: " << Iter::kWarpAccessM << " " << Iter::kWarpAccessK << std::endl;
os << "[A] warp access iters: " << Iter::kWarpIterM << " " << Iter::kWarpIterK << std::endl;
os << "[A] warp arrangement: " << Iter::kWarpM << " " << Iter::kWarpK << std::endl;
os << "[A] iterations: " << Iter::kIterM << " " << Iter::kIterK << std::endl;
os << "[A] iters per tile: " << Iter::kIterCount << std::endl;
os << "[A] warp footprint: " << Iter::kWarpFootprintM << " " << Iter::kWarpFootprintK << std::endl;
os << "[A] shared memory: " << Iter::kSmemByteSize << std::endl;
os << std::endl;
}
{
using Iter = typename GemmType::IteratorB;
os << "[B] shape: " << Iter::kShapeK << " " << Iter::kShapeN << std::endl;
os << "[B] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
os << "[B] warp shape per access: " << Iter::kWarpAccessK << " " << Iter::kWarpAccessN << std::endl;
os << "[B] warp access iters: " << Iter::kWarpIterK << " " << Iter::kWarpIterN << std::endl;
os << "[B] warp arrangement: " << Iter::kWarpK << " " << Iter::kWarpN << std::endl;
os << "[B] iterations: " << Iter::kIterK << " " << Iter::kIterN << std::endl;
os << "[B] iters per tile: " << Iter::kIterCount << std::endl;
os << "[B] warp footprint: " << Iter::kWarpFootprintK << " " << Iter::kWarpFootprintN << std::endl;
os << "[B] shared memory: " << Iter::kSmemByteSize << std::endl;
os << std::endl;
}
{
using Iter = typename GemmType::IteratorQ;
// os << "[Q] shape: " << CTA_M << " " << Iter::SLICE_K << std::endl;
os << "[Q] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
os << "[Q] warp shape per access: " << Iter::kWarpAccessM << " " << Iter::kWarpAccessK << std::endl;
os << "[Q] warp access iters: " << Iter::kWarpIterM << " " << Iter::kWarpIterK << std::endl;
os << "[Q] warp arrangement: " << Iter::kWarpM << " " << Iter::kWarpK << std::endl;
os << "[Q] iterations: " << Iter::kIterM << " " << Iter::kIterK << std::endl;
os << "[Q] iters per tile: " << Iter::kIterCount << std::endl;
os << "[Q] warp footprint: " << Iter::kWarpFootprintM << " " << Iter::kWarpFootprintK << std::endl;
os << "[Q] size per stage: " << Iter::kSizePerStage << std::endl;
os << "[Q] shared memory: " << Iter::kSmemByteSize << std::endl;
os << std::endl;
}
os << "Dynamic shared memory size: " << kSmemByteSize << std::endl;
}
};
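// Illustrative host-side usage sketch. The concrete CtaShape/WarpShape tile types
// and the OutputOps epilogue are defined elsewhere in gemm_s4_f16; the shapes and
// names below are placeholders, not the library's actual configuration.
//
//   using Kernel = GemmKernel<CtaShape /*e.g. 64x128x64*/, WarpShape /*e.g. 64x32x64*/,
//                             /*Stages=*/3, /*GroupSize=*/128, OutputOps>;
//   Kernel kernel;                      // sets the dynamic smem limit, queries occupancy
//   Metric metric{};
//   kernel.GetMetric(metric, m, n, k);  // feasibility / occupancy / cost estimates
//   if (metric.feasible) {
//       kernel.Launch(C, A, B, Q, m, n, k, /*output_op_idx=*/0, stream);
//   }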
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "common.h"
#include "cta_iterator.h"
#include "warp_iterator.h"
#include <cuda_pipeline_primitives.h>
namespace turbomind {
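// Thin wrapper over the SM80 tensor-core instruction mma.sync.m16n8k16 (f16 inputs,
// f32 accumulation): d = a * b + c on a 16x16 A tile and a 16x8 B tile, with the
// fragments distributed across the 32 lanes of the warp per the PTX layout
// (8 halves of A, 4 halves of B and 4 floats of C/D per thread).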
__inline__ __device__ void
mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<half, 4>& b, Array<float, 4>& c)
{
#if TURBOMIND_ARCH_SM80
uint32_t const* A = reinterpret_cast<uint32_t const*>(&a);
uint32_t const* B = reinterpret_cast<uint32_t const*>(&b);
float const* C = reinterpret_cast<float const*>(&c);
float* D = reinterpret_cast<float*>(&d);
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, "
"{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
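// movmatrix-based warp-wide transpose of an 8x8 tile of 16-bit elements (SM75+);
// used by store_accum below to rearrange accumulator fragments before the epilogue
// writes them to global memory.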
__inline__ __device__ uint transpose_m8n8_b16(uint a)
{
#if TURBOMIND_ARCH_SM75
uint d;
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(d) : "r"(a));
return d;
#else
assert(TURBOMIND_ARCH_SM75);
return 0;
#endif
}
namespace ops {
__inline__ __device__ float4 operator+(const float4& a, const float4& b)
{
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}
__inline__ __device__ float2 operator+(const float2& a, const float2& b)
{
return {a.x + b.x, a.y + b.y};
}
} // namespace ops
template<int CTA_M,
int CTA_N,
int CTA_K,
int WARP_M,
int WARP_N,
int WARP_K,
int STAGES,
int GROUP_SIZE,
typename OutputOps>
struct Gemm {
static constexpr int kWarpCountM = CTA_M / WARP_M;
static constexpr int kWarpCountN = CTA_N / WARP_N;
static constexpr int kWarpCountK = CTA_K / WARP_K;
static constexpr int kWarpCountMN = kWarpCountM * kWarpCountN;
static constexpr int kWarpCount = kWarpCountMN * kWarpCountK;
static constexpr int SLICES = kWarpCountK;
static constexpr int SLICE_K = CTA_K / SLICES;
static_assert(SLICE_K % WARP_K == 0, "infeasible sliced-k setting");
using IteratorA = turbomind::IteratorA<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES>;
using IteratorQ = turbomind::IteratorQ<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES, GROUP_SIZE>;
using IteratorB = turbomind::IteratorB<kWarpCountMN, CTA_M, CTA_N, CTA_K, STAGES, SLICES>;
static constexpr int OP_M = 16;
static constexpr int OP_N = 8;
static constexpr int OP_K = 16;
using WarpIterA = turbomind::WarpIteratorA<CTA_M,
CTA_K,
WARP_M,
WARP_K,
OP_M,
OP_K,
GROUP_SIZE,
STAGES,
IteratorA::kSizePerStage,
IteratorQ::kSizePerStage>;
using WarpIterB =
turbomind::WarpIteratorB<CTA_N, CTA_K, WARP_N, WARP_K, OP_N, OP_K, IteratorB::kSmemPadCtaK, STAGES>;
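// Mainloop body for one CTA_K tile: register fragments are double buffered
// (warp_frag_A_/warp_frag_B_[2]), cp.async prefetches for the next stage are issued
// in small batches across the K iterations, and at the tile boundary the pipeline is
// committed, the next stage is awaited and the slice-local barrier is passed before
// the CTA and warp iterators advance.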
__device__ void warp_mma(IteratorA& iter_A,
IteratorQ& iter_Q,
IteratorB& iter_B,
WarpIterA& warp_iter_A,
WarpIterB& warp_iter_B,
float* accum,
int slice_id,
int& gemm_iter)
{
constexpr int ITER_M = WARP_M / OP_M;
constexpr int ITER_N = WARP_N / OP_N;
constexpr int ITER_K = WARP_K / OP_K;
constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;
auto frag_C_ptr = (Array<float, 4>*)accum; // [ITER_N, ITER_M]
PRAGMA_UNROLL
for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {
warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
auto warp_frag_A = warp_frag_A_[iter_k % 2];
auto warp_frag_B = warp_frag_B_[iter_k % 2];
PRAGMA_UNROLL
for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
PRAGMA_UNROLL
for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
auto& frag_A = warp_frag_A[iter_m];
auto& frag_B = warp_frag_B[iter_n];
auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
}
}
if (iter_k < ITER_K - 1) {
iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
}
if (iter_k == ITER_K - 2) {
iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id);
iter_A.next_stage();
iter_Q.next_stage();
iter_B.next_stage();
warp_iter_A.next_stage();
warp_iter_B.next_stage();
--gemm_iter;
}
}
}
template<typename T, int N>
__device__ static void copy(T (&dst)[N], const T (&src)[N])
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = src[i];
}
}
template<typename T, int N>
__device__ static void clear(T (&dst)[N])
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = T{};
}
}
__device__ void sync_slice(int slice_id)
{
if constexpr (SLICES == 1) {
__syncthreads();
}
else {
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
}
__device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
{
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) {
tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
}
}
}
__device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
{
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) {
partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
}
}
}
template<int Index>
__device__ void store_accum(float* tb_frag_C,
float* tb_smem_C,
half* C,
int m,
int n,
int cta_m,
int cta_n,
int warp_id_m,
int warp_id_n,
int lane_id,
int slice_id)
{
if (slice_id != 0) {
return;
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 2; ++x) {
const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
// convert to half
half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// transpose 8x8 accum tile
uint trans_C = transpose_m8n8_b16((uint&)half_C);
// store to global memory
OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
}
}
}
}
__device__ void
sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
{
int offset_m = warp_id_m * WARP_M / OP_M;
int offset_n = warp_id_n * WARP_N / OP_N;
PRAGMA_UNROLL
for (int z = 0; z < SLICES; ++z) {
if (slice_id == z) {
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) {
int src = (i * WARP_M / OP_M + j) * 4 + x;
int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
if (z > 0) {
using namespace ops;
tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
}
tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
}
}
}
}
__syncthreads();
}
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) {
int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
int dst = (i * WARP_M / OP_M + j) * 4 + x;
tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
}
}
}
}
}
Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
__device__ void run_v2(half* __restrict__ C,
const uint* __restrict__ A,
const half* __restrict__ B,
const half2* __restrict__ Q,
int M,
int N,
int K,
int output_op_idx)
{
static_assert(WARP_M % OP_N == 0);
float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];
extern __shared__ uint8_t smem[];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id_m = warp_id % kWarpCountM;
const int warp_id_nk = warp_id / kWarpCountM;
const int warp_id_n = warp_id_nk % kWarpCountN;
const int warp_id_k = warp_id_nk / kWarpCountN;
const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;
const int slice_id = warp_id_k;
const int cta_k = slice_id * SLICE_K; // sliced-k offset
const int cta_m = blockIdx.x * CTA_M;
const int cta_n = blockIdx.y * CTA_N;
// each slice has its own partition of smem
uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);
// [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
float* const tb_smem_C = (float*)smem;
__shared__ typename IteratorQ::Storage tb_smem_Q_storage;
auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];
IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};
const int offset_m = warp_id_m * WARP_M + lane_id;
WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);
int gemm_iter = (K + CTA_K - 1) / CTA_K;
PRAGMA_UNROLL
for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
iter_A.prefetch_stage(gemm_iter > 0);
iter_Q.prefetch_stage(gemm_iter > 0);
iter_B.prefetch_stage(gemm_iter > 0);
__pipeline_commit();
}
clear(tb_frag_C);
__pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id);
warp_iter_A.load(warp_frag_A_[0], 0);
warp_iter_B.load(warp_frag_B_[0], 0);
PRAGMA_NO_UNROLL
for (; gemm_iter > -STAGES + 1;) {
warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1) {
sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
}
switch (output_op_idx) {
case 0:
store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break;
case 1:
store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break;
default:
return;
}
}
};
template<typename Gemm>
__global__ void gemm_s4_f16_nn(half* __restrict__ C,
const uint* __restrict__ A,
const half* __restrict__ B,
const half2* __restrict__ Q,
int M,
int N,
int K,
int output_op_idx)
{
Gemm{}.run_v2(C, A, B, Q, M, N, K, output_op_idx);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <array>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>
namespace turbomind {
struct Metric {
int id;
bool feasible;
bool prefer;
std::array<int, 3> cta_shape;
std::array<int, 3> warp_shape;
int warps;
int stages;
int max_active_ctas;
float smem;
float cta_cnt_m;
float cta_cnt_n;
float cta_iter_k;
float grid_size;
int active_ctas;
float waves;
float waves1;
float occupancy;
float tile_efficiency;
float wave_efficiency;
float grid_a0;
float grid_b0;
float grid_a1;
float grid_b1;
float grid_mm;
float grid_sum;
float grid_norm;
float cta_sum;
float cta_wave;
int best;
float time;
int count;
};
inline void DumpMetrics(std::ostream& os, const std::vector<Metric>& metrics, const std::vector<int>& indices = {})
{
auto dump_shape = [](const std::array<int, 3>& shape) {
std::stringstream ss;
ss << std::setw(4) << shape[0] << std::setw(4) << shape[1] << std::setw(4) << shape[2];
return ss.str();
};
std::vector<std::tuple<std::string, int>> infos{
{"id", 4}, {"valid", 6}, {"cta_mnk", 14}, {"warp_mnk", 14}, {"warps", 6}, {"stages", 8},
{"smem", 8}, {"cta_cnt_m", 10}, {"cta_cnt_n", 10}, {"cta_iter_k", 11}, {"max_ctas", 9}, {"act_ctas", 10},
{"waves", 12}, {"waves1", 12}, {"occupancy", 12}, {"%tile", 10}, {"%wave", 10}, {"grid_a0", 12},
{"grid_b0", 12}, {"grid_a1", 12}, {"grid_b1", 12}, {"grid_mm", 12}, {"grid_sum", 12}, {"cta_cnt", 8},
{"cta_sum", 8}, {"cta_wave", 9}, {"grid_norm", 12}, {"time", 12}, {"best", 7}};
for (const auto& [name, width] : infos) {
os << std::setw(width) << name;
}
os << "\n";
for (size_t i = 0; i < metrics.size(); ++i) {
auto& metric = indices.empty() ? metrics[i] : metrics[indices[i]];
int c = 0;
os << std::setw(std::get<1>(infos[c++])) << metric.id;
os << std::setw(std::get<1>(infos[c++])) << metric.feasible;
os << std::setw(std::get<1>(infos[c++])) << dump_shape(metric.cta_shape);
os << std::setw(std::get<1>(infos[c++])) << dump_shape(metric.warp_shape);
os << std::setw(std::get<1>(infos[c++])) << metric.warps;
os << std::setw(std::get<1>(infos[c++])) << metric.stages;
os << std::setw(std::get<1>(infos[c++])) << metric.smem;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_cnt_m;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_cnt_n;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_iter_k;
os << std::setw(std::get<1>(infos[c++])) << metric.max_active_ctas;
os << std::setw(std::get<1>(infos[c++])) << metric.active_ctas;
os << std::setw(std::get<1>(infos[c++])) << metric.waves;
os << std::setw(std::get<1>(infos[c++])) << metric.waves1;
os << std::setw(std::get<1>(infos[c++])) << metric.occupancy;
os << std::setw(std::get<1>(infos[c++])) << metric.tile_efficiency;
os << std::setw(std::get<1>(infos[c++])) << metric.wave_efficiency;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_a0;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_b0;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_a1;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_b1;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_mm;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_sum;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_size;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_sum;
os << std::setw(std::get<1>(infos[c++])) << metric.cta_wave;
os << std::setw(std::get<1>(infos[c++])) << metric.grid_norm;
os << std::setw(std::get<1>(infos[c++])) << metric.time * 1000 / metric.count;
os << std::setw(std::get<1>(infos[c++])) << (metric.best ? "*" : "");
os << "\n";
}
}
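// Illustrative use of the cost model: a dispatcher can fill one Metric per kernel
// variant via IGemmKernel::GetMetric, rank the feasible entries (e.g. by grid_norm
// or measured time) and print the table for inspection. `kernels` below is a
// hypothetical container of IGemmKernel pointers, not an object from these sources.
//
//   std::vector<Metric> metrics(kernels.size());
//   for (size_t i = 0; i < kernels.size(); ++i) {
//       metrics[i].id = static_cast<int>(i);
//       kernels[i]->GetMetric(metrics[i], m, n, k);
//   }
//   DumpMetrics(std::cout, metrics);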
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "common.h"
namespace turbomind {
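// Warp-level iterator over the packed 4-bit A operand: loads uint4 fragments and the
// matching per-group half2 quantization parameters from shared memory, then
// dequantizes to fp16 and applies the scales/zeros in registers right before the
// tensor-core MMA (see dequantize_s4_to_fp16x2_v2 / apply_Q in load()).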
template<int CTA_M,
int CTA_K,
int WARP_M,
int WARP_K,
int OP_M,
int OP_K,
int GROUP_SIZE,
int STAGES,
int kSizePerStageA,
int kSizePerStageQ>
struct WarpIteratorA {
static_assert(WARP_K % GROUP_SIZE == 0 || GROUP_SIZE % WARP_K == 0);
static constexpr int ITER_M = 32 / OP_M;
static constexpr int ITER_X = WARP_M / 32;
uint4 frag_A4_[ITER_X];   // 8 values per uint
half2 frag_Q_[ITER_X][4]; // 4 m8k8 tiles along M, as WARP_M == 32
const uint4* smem_A_;
const half2* smem_Q_;
const int offset_m_;
const int offset_m_Q_;
int stage_{0};
int offset_A_{0};
int offset_Q_{0};
__device__ WarpIteratorA(uint4* smem_A, half2* smem_Q, int warp_id, int lane_id, int offset_m, int offset_k):
smem_A_(smem_A), smem_Q_(smem_Q), offset_m_(offset_m), offset_m_Q_(offset_m / 32 * 32 + lane_id / 4)
{
}
// iter_k must be a compile-time constant
__device__ void load(Array<half, 8>* data, int iter_k)
{
// load A
// smem_A uint4 [SLICE_K/32, CTA_M/32, WARP_SIZE], load as uint4 to avoid bank-conflicts
if (iter_k % 2 == 0) {
PRAGMA_UNROLL
for (int x = 0; x < ITER_X; ++x) {
frag_A4_[x] = smem_A_[offset_A_ + (iter_k / 2) * CTA_M + x * 32 + offset_m_];
}
}
// load Q
if (iter_k * OP_K % GROUP_SIZE == 0) {
const int g = iter_k * OP_K / GROUP_SIZE;
PRAGMA_UNROLL
for (int x = 0; x < ITER_X; ++x) {
PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
const int mm = offset_m_Q_ + x * 32 + i * 8; // stride of m8k8 tile
((uint&)frag_Q_[x][i]) = ((uint&)smem_Q_[offset_Q_ + g * CTA_M + mm]);
}
}
}
PRAGMA_UNROLL
for (int x = 0; x < ITER_X; ++x) {
const uint* frag_A = (uint*)&frag_A4_[x];
PRAGMA_UNROLL
for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
uint4 tmp = dequantize_s4_to_fp16x2_v2(frag_A[iter_k % 2 * 2 + iter_m]);
auto& vec = (Array<half2, 4>&)tmp;
vec[0] = apply_Q(vec[0], frag_Q_[x][iter_m * 2]);
vec[1] = apply_Q(vec[1], frag_Q_[x][iter_m * 2 + 1]);
vec[2] = apply_Q(vec[2], frag_Q_[x][iter_m * 2]);
vec[3] = apply_Q(vec[3], frag_Q_[x][iter_m * 2 + 1]);
data[x * ITER_M + iter_m] = (Array<half, 8>&)vec;
}
}
}
__device__ void next_stage()
{
++stage_;
if (stage_ >= STAGES) {
stage_ = 0;
}
offset_A_ = stage_ * kSizePerStageA;
offset_Q_ = stage_ * kSizePerStageQ;
}
};
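// Warp-level iterator over the fp16 B operand staged in shared memory; fragments are
// fetched with ldmatrix.m8n8 (x4 per step, or x2 when WARP_N == 8).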
template<int CTA_N, int CTA_K, int WARP_N, int WARP_K, int OP_N, int OP_K, int SMEM_STRIDE, int STAGES>
struct WarpIteratorB {
static constexpr int kLdsmNum = WARP_N == 8 ? 2 : 4;
static constexpr int ITER_N = WARP_N / OP_N;
static constexpr int ITER_K = WARP_K / OP_K;
static_assert(OP_N == 8 && OP_K == 16);
const int warp_id_n_;
const int lane_id_;
const int ldsm_group_id_;
const int offset_k_;
int offset_n_;
const uint32_t smem_base_ptr_;
uint32_t smem_ptr_;
int stage_{0};
__device__ WarpIteratorB(uint32_t smem_int_ptr, int warp_id_n, int lane_id, int offset_k):
smem_base_ptr_(smem_int_ptr),
smem_ptr_(smem_base_ptr_),
warp_id_n_(warp_id_n),
lane_id_(lane_id),
ldsm_group_id_(lane_id / 8),
offset_k_(ldsm_group_id_ % 2 * 8 + offset_k),
offset_n_(ldsm_group_id_ / 2 * 8 + lane_id % 8)
{
if (kLdsmNum == 2) {
offset_n_ -= ldsm_group_id_ / 2 * 8;
}
offset_n_ += warp_id_n_ * WARP_N;
}
__device__ void load(Array<half, 4>* data, int iter_k)
{
const int kk = iter_k * OP_K + offset_k_;
auto ptr = (uint*)data;
PRAGMA_UNROLL
for (int iter_n = 0; iter_n < ITER_N;) {
const int nn = offset_n_ + iter_n * OP_N;
auto src = smem_ptr_ + sizeof(half) * (nn * SMEM_STRIDE + kk);
if constexpr (kLdsmNum == 4) {
ldmatrix_m8n8_x4_b16(ptr[0], ptr[1], ptr[2], ptr[3], src);
ptr += 4;
iter_n += 2;
}
else {
ldmatrix_m8n8_x2_b16(ptr[0], ptr[1], src);
ptr += 2;
iter_n += 1;
}
}
}
__device__ void next_stage()
{
++stage_;
if (stage_ >= STAGES) {
stage_ = 0;
}
smem_ptr_ = smem_base_ptr_ + stage_ * sizeof(half) * CTA_N * SMEM_STRIDE;
}
};
} // namespace turbomind
...@@ -21,6 +21,7 @@ add_library(Llama STATIC
set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(Llama PUBLIC -lcudart
+                                  gemm_s4_f16
                                   cublasMMWrapper
                                   DynamicDecodeLayer
                                   activation_kernels
......
...@@ -19,6 +19,7 @@
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
+#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"
#include <filesystem>
...@@ -31,6 +32,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
                                                    size_t size_per_head,
                                                    size_t inter_size,
                                                    WeightType weight_type,
+                                                   int group_size,
                                                    bool attn_bias,
                                                    size_t tensor_para_size,
                                                    size_t tensor_para_rank):
...@@ -47,22 +49,32 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
    self_attn_weights.qkv.input_dims = hidden_units_;
    self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
    self_attn_weights.qkv.type = weight_type;
+   self_attn_weights.qkv.group_size = group_size;
    self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
    self_attn_weights.output.output_dims = hidden_units_;
    self_attn_weights.output.type = weight_type;
+   self_attn_weights.output.group_size = group_size;
    ffn_weights.gating.input_dims = hidden_units_;
    ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
    ffn_weights.gating.type = weight_type;
+   ffn_weights.gating.group_size = group_size;
    ffn_weights.intermediate.input_dims = hidden_units_;
    ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
    ffn_weights.intermediate.type = weight_type;
+   ffn_weights.intermediate.group_size = group_size;
+   ffn_weights.fused_gating_intermediate.input_dims = hidden_units_;
+   ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
+   ffn_weights.fused_gating_intermediate.type = weight_type;
+   ffn_weights.fused_gating_intermediate.group_size = group_size;
    ffn_weights.output.input_dims = inter_size_ / tensor_para_size_;
    ffn_weights.output.output_dims = hidden_units_;
    ffn_weights.output.type = weight_type;
+   ffn_weights.output.group_size = group_size;
    mallocWeights();
}
...@@ -71,13 +83,11 @@ void freeWeights(LlamaDenseWeight<T>& weights)
{
    cudaFree(weights.kernel);
    cudaFree(weights.bias);
-   cudaFree(weights.scales);
-   cudaFree(weights.zeros);
+   cudaFree(weights.scales_and_zeros);
    weights.kernel = nullptr;
    weights.bias = nullptr;
-   weights.scales = nullptr;
-   weights.zeros = nullptr;
+   weights.scales_and_zeros = nullptr;
}
template<typename T>
...@@ -93,9 +103,10 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
    else { // int8, int4
        const int factor = sizeof(float) * 8 / bit_size;
        FT_CHECK(weights.input_dims % factor == 0);
-       deviceMalloc((float**)&weights.kernel, weights.input_dims / factor * weights.output_dims);
-       deviceMalloc((T**)&weights.scales, weights.output_dims);
-       deviceMalloc((T**)&weights.zeros, weights.output_dims);
+       deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
+       deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
+       // interleaved scales/zeros
+       deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
    }
}
...@@ -195,40 +206,15 @@ void loadWeights(LlamaDenseWeight<T>& w,
    }
    else { // int8, int4
        const int factor = sizeof(float) * 8 / bit_size;
-       FT_CHECK(dim0 % factor == 0);
-       const auto f32_type = FtCudaDataType::FP32;
-       std::vector<ConcateSlice> weight_slices{};
-       std::vector<ConcateSlice> bias_slices{};
-       if (enable_slice) {
-           if (slice_dim == 1) {
-               size_t start = 0;
-               ConcateSlice slice0{.slices = {{0, dim0}}};
-               ConcateSlice slice1{.slices = {{}}};
-               for (auto len : slice_shape) {
-                   size_t stride = len / tensor_para_size;
-                   slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
-                   start += len;
-               }
-               weight_slices = {slice0, slice1};
-               ConcateSlice bias_slice0{.slices = {{0, 1}}};
-               bias_slices = {bias_slice0, slice1};
-           }
-           else {
-               size_t start = 0;
-               ConcateSlice slice0{.slices = {}};
-               ConcateSlice slice1{.slices = {{0, dim1}}};
-               for (auto len : slice_shape) {
-                   size_t stride = len / factor / tensor_para_size;
-                   slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
-                   start += len;
-               }
-               weight_slices = {slice0, slice1};
-           }
-       }
-       loadWeightFromBin((float*)w.kernel, {dim0 / factor, dim1}, prefix + ".qweight", f32_type, weight_slices);
-       loadWeightFromBin((T*)w.scales, {1, dim1}, prefix + ".scales", type, bias_slices);
-       loadWeightFromBin((T*)w.zeros, {1, dim1}, prefix + ".zeros", type, bias_slices);
+       FT_CHECK(dim1 % factor == 0);
+       std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
+       loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
+       const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
+       loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
    }
}
...@@ -241,8 +227,14 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
    turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_);
    turbomind::mallocWeights(self_attn_weights.output, attn_bias_);
-   turbomind::mallocWeights(ffn_weights.gating, false);
-   turbomind::mallocWeights(ffn_weights.intermediate, false);
+   if (weight_type_ == WeightType::kINT4) {
+       turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false);
+   }
+   else {
+       turbomind::mallocWeights(ffn_weights.gating, false);
+       turbomind::mallocWeights(ffn_weights.intermediate, false);
+   }
    turbomind::mallocWeights(ffn_weights.output, false);
}
...@@ -254,8 +246,15 @@ LlamaDecoderLayerWeight<T>::~LlamaDecoderLayerWeight()
    freeWeights(self_attn_weights.qkv);
    freeWeights(self_attn_weights.output);
-   freeWeights(ffn_weights.gating);
-   freeWeights(ffn_weights.intermediate);
+   if (weight_type_ == WeightType::kINT4) {
+       freeWeights(ffn_weights.fused_gating_intermediate);
+   }
+   else {
+       freeWeights(ffn_weights.gating);
+       freeWeights(ffn_weights.intermediate);
+   }
    freeWeights(ffn_weights.output);
}
...@@ -276,9 +275,22 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
                tensor_para_size_,
                1,
                {head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_});
    loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0);
-   loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1);
-   loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
+   if (weight_type_ == WeightType::kINT4) {
+       loadWeights(ffn_weights.fused_gating_intermediate,
+                   dir_path + ".feed_forward.w13",
+                   tensor_para_rank_,
+                   type,
+                   tensor_para_size_,
+                   1);
+   }
+   else {
+       loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1);
+       loadWeights(
+           ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
+   }
    loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0);
    // load kv_cache quant scale
......
...@@ -33,6 +33,7 @@ public:
                            size_t size_per_head,
                            size_t inter_size,
                            WeightType weight_type,
+                           int group_size,
                            bool attn_bias,
                            size_t tensor_para_size,
                            size_t tensor_para_rank);
......
...@@ -48,18 +48,18 @@ inline size_t getBitSize(WeightType type)
        case WeightType::kINT4:
            return 4;
    }
+   return 0;
}
template<typename T>
struct LlamaDenseWeight {
    size_t input_dims;
    size_t output_dims;
    void* kernel;
    WeightType type;
    T* bias;
-   T* scales;
-   T* zeros;
+   T* scales_and_zeros;
+   int group_size;
};
template<typename T>
...@@ -74,6 +74,7 @@ struct LlamaFfnWeight {
    LlamaDenseWeight<T> gating;
    LlamaDenseWeight<T> intermediate;
    LlamaDenseWeight<T> output;
+   LlamaDenseWeight<T> fused_gating_intermediate;
};
} // namespace turbomind
...@@ -85,12 +85,16 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
    T* ffn_output_data = output_tensors->at("ffn_output").getPtr<T>();
    PUSH_RANGE("ffn");
-   // TODO: fuse the two GEMMs with activation
-   linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
-   linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
-   activation(num_token);
+   if (weights->fused_gating_intermediate.kernel) {
+       linear_.forward(
+           gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
+   }
+   else {
+       linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
+       linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
+       activation(num_token);
+   }
    linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
    POP_RANGE;
......