Commit 8efb9210 authored by zhouxiang's avatar zhouxiang
Browse files

Merge branch 'dtk24.04-v0.2.6_awq' into 'dtk24.04-v0.2.6'

合入dtk2404-v0.2.6版本int4量化推理部分

See merge request dcutoolkit/deeplearing/lmdeploy!2
parents 2326380c 175eaedb
......@@ -5,7 +5,7 @@ __pycache__/
.vscode/
.idea/
# C extensions
*.so
#*.so
# Distribution / packaging
.Python
......
该处so主要用于awq功能使用
\ No newline at end of file
......@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED
# $<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
# $<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend>
......
......@@ -61,7 +61,11 @@ class CLI(object):
default=0,
help='A parameter used in awq to quantize fp16 weights '
'to 4 bits')
parser.add_argument(
'--w4-weight-layout',
type=int,
default=2,
help='A parameter used in AWQ to control the layout of weight ')
parser.set_defaults(run=CLI.convert)
@staticmethod
......
......@@ -196,6 +196,7 @@ def main(model_name: str,
tp: int = 1,
quant_path: str = None,
group_size: int = 0,
w4_weight_layout: int = 2,
**kwargs):
"""deploy llama family models via turbomind.
......@@ -215,6 +216,7 @@ def main(model_name: str,
quant_path (str): Path of the quantized model, which can be None.
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
w4_weight_layout (int) :a parameter used in AWQ to control the layout of weight
kwargs (dict): other params for convert
"""
......@@ -260,10 +262,13 @@ def main(model_name: str,
cfg.tensor_para_size = tp
cfg.rotary_embedding = cfg.size_per_head
cfg.group_size = group_size
cfg.w4_weight_layout=w4_weight_layout
if inferred_model_format.find('awq') != -1:
cfg.weight_type = 'int4'
output_format = 'w4'
assert group_size > 0, f'group_size: {group_size} should > 0'
print("w4_weight_layout:",w4_weight_layout)
assert w4_weight_layout>=0 and w4_weight_layout<3,f'w4_weight_layout: {w4_weight_layout} should >= 0 and < 3'
else:
#output_format = update_output_format(model_name, inferred_model_format,
# model_path, output_format)
......
......@@ -5,6 +5,7 @@ import inspect
import io
import json
import os.path as osp
import os
from abc import ABC, abstractmethod
from configparser import ConfigParser
......@@ -52,6 +53,7 @@ class TurbomindModelConfig:
rope_theta: float = 10000.0
size_per_head: int = 128
group_size: int = 0
w4_weight_layout : int = 2
max_batch_size: int = 64
max_context_token_num: int = 1
step_length: int = 1
......@@ -150,6 +152,12 @@ class BaseOutputModel(ABC):
self.to_file = to_file
self.out_dir = out_dir
self.tm_params = {}
#self.weight_layout= 1
#获取环境变量
#env_weight_layout = os.environ.get('LMDEPLOY_WEIGHTLAYOUT_SWITCH', '1')
#self.weight_layout =int(env_weight_layout)
#print("self.weight_layout:",self.weight_layout)
@abstractmethod
def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
......@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128):
return x.view(n_heads, 2, dim // n_heads // 2,
1).transpose(1, 2).reshape(dim, 1)
def permute_trans(x: torch.Tensor):
if x.shape[-1]>1:
dim = x.shape[-1]
return x.view(-1, x.shape[-1]).transpose(0, 1).reshape(dim,-1)
def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
dim: int):
......
......@@ -8,7 +8,7 @@ import lmdeploy
from ..source_model.base import BaseInputModel, BaseReader
from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig,
merge_qkv, permute)
merge_qkv, permute,permute_trans)
# import _turbomind as _tm
# TODO: find another way import _turbomind
......@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz
def convert_s4_(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
assert qw.is_contiguous()
assert qz.is_contiguous()
assert s.is_contiguous()
_qw = torch.zeros_like(qw)
_sz = torch.zeros_like(s, dtype=torch.int32) # half2
_ws = torch.zeros_like(s)
_tm.convert_s4_k_m8_(_qw, _sz, _ws, qw, s, qz,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz
def tp_m_s4(x: torch.Tensor, tp: int):
return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
......@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel):
"""Export transformer layer i."""
group_size = self.cfg.group_size
tp = self.cfg.tensor_para_size
w4_weight_layout = self.cfg.w4_weight_layout
size_per_head = self.cfg.size_per_head
# attn
q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i))
......@@ -121,12 +134,45 @@ class TurbomindW4Model(BaseOutputModel):
qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp)
pad_group_count=2
if w4_weight_layout==1 or w4_weight_layout==2:
if qkv_qw.shape[0]%4096==0:
qkv_qw_padding=torch.zeros(group_size*pad_group_count,qkv_qw.shape[1],dtype=torch.int32).cuda()
qkv_qw =torch.cat((qkv_qw,qkv_qw_padding),dim=0).contiguous()
qkv_qz_padding =torch.zeros(pad_group_count,qkv_qz.shape[1],dtype=torch.int32).cuda()
qkv_qz =torch.cat((qkv_qz,qkv_qz_padding),dim=0).contiguous()
qkv_s_padding =torch.zeros(pad_group_count,qkv_s.shape[1],dtype=torch.float16).cuda()
qkv_s =torch.cat((qkv_s,qkv_s_padding),dim=0).contiguous()
qkv_qw, qkv_sz = convert_s4_(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp)
qkv_sz = permute_trans(qkv_sz)
else:
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp)
#print("请设置weight layout\n")
self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1)
self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1)
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
if w4_weight_layout==1 or w4_weight_layout==2:
if o_qw.shape[0]%4096==0:
o_qw_padding=torch.zeros(group_size*pad_group_count,o_qw.shape[1],dtype=torch.int32).cuda()
o_qw =torch.cat((o_qw,o_qw_padding),dim=0).contiguous()
o_qz_padding =torch.zeros(pad_group_count,o_qz.shape[1],dtype=torch.int32).cuda()
o_qz =torch.cat((o_qz,o_qz_padding),dim=0).contiguous()
o_s_padding =torch.zeros(pad_group_count,o_s.shape[1],dtype=torch.float16).cuda()
o_s =torch.cat((o_s,o_s_padding),dim=0).contiguous()
o_qw, o_sz = convert_s4_(o_qw, o_qz, o_s, group_size)
o_sz = permute_trans(o_sz)
else:
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0)
self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0)
......@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel):
w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
w3_s)
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp)
if w4_weight_layout==1 or w4_weight_layout==2:
if w13_qw.shape[0]%4096==0:
w13_qw_padding=torch.zeros(group_size*pad_group_count,w13_qw.shape[1],dtype=torch.int32).cuda()
w13_qw =torch.cat((w13_qw,w13_qw_padding),dim=0).contiguous()
w13_qz_padding =torch.zeros(pad_group_count,w13_qz.shape[1],dtype=torch.int32).cuda()
w13_qz =torch.cat((w13_qz,w13_qz_padding),dim=0).contiguous()
w13_s_padding =torch.zeros(pad_group_count,w13_s.shape[1],dtype=torch.float16).cuda()
w13_s =torch.cat((w13_s,w13_s_padding),dim=0).contiguous()
w13_qw, w13_sz = convert_s4_(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp)
w13_sz = permute_trans(w13_sz)
else:
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp)
self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1)
self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros',
-1)
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
if w4_weight_layout==1 or w4_weight_layout==2:
#pading
if w2_qw.shape[0]%4096==0:
w2_qw_padding=torch.zeros(group_size*pad_group_count,w2_qw.shape[1],dtype=torch.int32).cuda()
w2_qw =torch.cat((w2_qw,w2_qw_padding),dim=0).contiguous()
w2_qz_padding =torch.zeros(pad_group_count,w2_qz.shape[1],dtype=torch.int32).cuda()
w2_qz =torch.cat((w2_qz,w2_qz_padding),dim=0).contiguous()
w2_s_padding =torch.zeros(pad_group_count,w2_s.shape[1],dtype=torch.float16).cuda()
w2_s =torch.cat((w2_s,w2_s_padding),dim=0).contiguous()
w2_qw, w2_sz = convert_s4_(w2_qw, w2_qz, w2_s, group_size)
w2_sz = permute_trans(w2_sz)
else:
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0)
self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0)
......
......@@ -147,6 +147,7 @@ class TurboMind:
model_name: Optional[str] = None,
model_format: Optional[str] = None,
group_size: Optional[int] = None,
w4_weight_layout: Optional[int] = None,
tp: Optional[int] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
**kwargs):
......@@ -179,6 +180,7 @@ class TurboMind:
engine_config = _update_engine_config(engine_config,
model_format=model_format,
group_size=group_size,
w4_weight_layout=w4_weight_layout,
tp=tp,
**kwargs)
......@@ -304,6 +306,7 @@ class TurboMind:
output_format = 'w4'
data_type = 'int4'
cfg.group_size = 128
cfg.w4_weight_layout=2
else:
# output_format = update_output_format(cfg.model_name,
# inferred_model_format,
......@@ -378,6 +381,7 @@ class TurboMind:
self.config = cfg
self.model_name = cfg.model_name
self.data_type = cfg.weight_type
#print("from_workspace_cfg:",cfg)
# create model
logger.warning(f'model_config:\n\n{cfg.toini()}')
......
......@@ -8,6 +8,7 @@ import subprocess
from typing import Optional, Union
from pathlib import Path
import torch
import shutil
pwd = os.path.dirname(__file__)
version_file = 'lmdeploy/version.py'
......@@ -69,7 +70,6 @@ def get_version_add(sha: Optional[str] = None) -> str:
file.writelines(lines)
file.close()
def get_version():
get_version_add()
version_file = 'lmdeploy/version.py'
......@@ -185,9 +185,24 @@ def parse_requirements(fname='requirements.txt', with_version=True):
packages += cuda_pkgs
return packages
def copy_ck_so():
lmdeploy_root = os.path.dirname(os.path.abspath(__file__))
so_path = os.path.join(os.path.join(lmdeploy_root, "3rdparty","composable_kernel"), "libgemm_multiB_int4.so")
# dtk version
target_path=os.path.join(lmdeploy_root, "lmdeploy","lib")
if os.path.exists(target_path):
shutil.copy(so_path, target_path)
elif os.getenv("ROCM_PATH"):
rocm_path = os.getenv('ROCM_PATH', "")
rocm_so_path = os.path.join(rocm_path, 'lib')
print("rocm_so_path:",rocm_so_path)
shutil.copy(so_path, rocm_so_path)
else:
shutil.copy(so_path, "usr/local/lib")
if __name__ == '__main__':
lmdeploy_package_data = ['lmdeploy/bin/llama_gemm']
copy_ck_so()
setup(
name='lmdeploy',
version=get_version(),
......
......@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
#set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#add_subdirectory(gemm_s_f16)
add_subdirectory(gemm_s_f16)
add_subdirectory(decoder_multihead_attention)
# Copyright (c) OpenMMLab. All rights reserved.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu)
target_compile_options(gemm_s4_f16 PRIVATE
......
......@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[0])
// : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
h[0]=(i4s & BOTTOM_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[1])
// : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
h[1]=(i4s & TOP_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[2])
// : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
h[2]=(top_i4s & BOTTOM_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[3])
// : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
printf("=========common.h 86\n");
h[3]=(top_i4s & TOP_MASK)|I4s_TO_F16s_MAGIC_NUM;
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
......@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers.
// Convert elt_01
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// // Convert elt_23
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// // Convert elt_45
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// // Convert elt_67
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
//asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
h[0]=h[0]-FP16_TOP_MAGIC_NUM;
// Convert elt_23
//asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h[1]=h[1]*ONE_SIXTEENTH+NEG_64;
// Convert elt_45
//asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
h[2]=h[2]-FP16_TOP_MAGIC_NUM;
// Convert elt_67
//asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h[3]=h[3]*ONE_SIXTEENTH+NEG_64;
return result;
}
......@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
printf("=========common.h 133\n");
// if (0) { // 1024 & 64
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
// }
// else { // 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// h[0] <<= 4;
// h[2] <<= 4;
// // we don't need to subtract the magic nums because zeros will go through the same dequant function
// // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
// 64 only, trade 4 hfma2 with 2 shifts
h[0] =(i4s & BOT_MASK) |MAGIC_NUM_2;
h[1] =(i4s & TOP_MASK) |MAGIC_NUM_1;
h[2] =(top_i4s & BOT_MASK) |MAGIC_NUM_2;
h[3] =(top_i4s & TOP_MASK) |MAGIC_NUM_1;
h[0] <<= 4;
h[2] <<= 4;
// we don't need to subtract the magic nums because zeros will go through the same dequant function
// and carry the same magic constant, the magic num will be canceled out after subtracting zeros
return result;
}
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
uint32_t smem_int_ptr;
......@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
uint s, z;
(half2&)z = __halves2half2(q.x, q.x);
(half2&)s = __halves2half2(q.y, q.y);
//uint s, z;
//(half2&)z = __halves2half2(q.x, q.x);
//(half2&)s = __halves2half2(q.y, q.y);
auto& t = (const uint&)x;
uint u, v;
//auto& t = (const uint&)x;
uint v;
// if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
// }
......@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
// asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
// asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
// }
printf("=========common.h 235\n");
return (half2&)v;
}
......
......@@ -2,6 +2,7 @@
#include "common.h"
#include <iostream>
#define BLOCKSIZE 256
namespace turbomind {
......@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
// permutation for [k, m/8] layout
Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
// |warp| lane | 2x2 | a0-7 |
permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
//permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4<0, 1, 2, 3, 4, 5, 6, 7, 8, 9><<<512, 512, 0, st>>>(dst, src, shape);
}
void reformat_s4_k_m8_tarnsw4(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k, m/8] layout
Array<int, 10> shape{1, k / 8, 2, 2, 2, 1, m / 8, 2, 2, 2};
// 0123456-->4,6,7,5,0,3,1,2
//permute_u4<4, 6, 7, 5, 0, 3, 1, 2><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4<5, 6, 8, 9, 7, 0, 1, 4, 2, 3><<<512, 512, 0, st>>>(dst, src, shape);
}
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
......@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst,
reformat_s4_k_m8(A_dst, A_src, m, k, st);
}
void convert_s4_k_m8_(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st)
{
dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
reformat_s4_k_m8_tarnsw4(A_dst, A_src, m, k, st);
}
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{
Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
......@@ -140,5 +167,118 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s
{
dequantize_s4_kernel<<<512, 512>>>(dst, src, count);
}
__global__ void dequant_kernel(int num_kernels,half* weight ,const half2* zeros_and_scales,int k,int n,int group_size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id%n;
int i=id/n;
half x=zeros_and_scales[i/group_size*n+j].data[0];
half y= zeros_and_scales[i/group_size*n+j].data[1];
float tmp=(weight[id]-x)*y;
weight[id]=__float2half(tmp);
}
__global__ void dequant_kernel_colmajor(int num_kernels,half* weight ,const half2* zeros_and_scales,int k,int n,int group_size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id/group_size;
half x=zeros_and_scales[j].data[0];
half y= zeros_and_scales[j].data[1];
float tmp=(weight[id]-x)*y;
weight[id]=__float2half(tmp);
}
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size)
{
dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
int num_kernels=k*n;
dequant_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,zeros_and_scales,k,n,group_size);
}
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size)
{
dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
int num_kernels=k*n;
dequant_kernel_colmajor<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,zeros_and_scales,k,n,group_size);
}
__global__ void FusedSiluActivation_kernel(int num_kernels,half* output ,const uint32_t* src,int m,int n)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
auto data = ((half2*)src)[id];
float x= __half2float(data.data[0]);
float y= __half2float(data.data[1]);
float silu=x / (1.f + __expf(-x))*y;
output[id]=__float2half(silu);
}
__global__ void assign_kernel(int num_kernels,half* output ,const half* src,int m,int n)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
output[id]=src[id];
}
void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type)
{
int num_kernels=m*n;
switch (type) {
case 0:
assign_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,src,m,n);
break;
case 1:
FusedSiluActivation_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(int(num_kernels/2),output,(const uint32_t*)src,m,n);
break;
default:
return;
}
}
template <typename T>
__global__ void input_padding_kernel(int num_kernels,T* output,const T* input,int m,int k,int group_size,int count)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id%(k+count*group_size);
int i=id/(k+count*group_size);
if(j<k)
{
output[i*(k+count*group_size)+j]=input[i*(k)+j];
}
else
{
output[i*(k+count*group_size)+j]=0.f;
}
}
template <typename T>
void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount)
{
//input的size是[m,k],output的size是[m,n+group_size]
//
int num_kernels=m*(k+pad_groupcount*group_size);
input_padding_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels, output,input,m,k,group_size,pad_groupcount);
}
#define INSTANTIATEINPUTPADING(T) \
template void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount);
INSTANTIATEINPUTPADING(__half)
} // namespace turbomind
......@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst,
int group_size,
cudaStream_t st = {});
void convert_s4_k_m8_(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st = {});
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
......
......@@ -9,10 +9,32 @@
#include <memory>
#include <vector>
typedef struct ihipStream_t* hipStream_t;
extern void run_weight_only_gemm(const void *A,
const void *B0,
const void *B1,
void *C,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideB_padded,
int StrideC,
int Group,
void* splitK_padA_workspace,
int splitK_padA_workspace_elementSize,
hipStream_t stream_id=0);
namespace turbomind {
extern bool g_dump_kernel_info_once;
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type);
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
template <typename T>
void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount);
class GemmS4F16 {
public:
......
......@@ -5,7 +5,7 @@
#include "common.h"
#include "cta_iterator.h"
#include "warp_iterator.h"
#include <cuda_pipeline_primitives.h>
//#include <cuda_pipeline_primitives.h>
namespace turbomind {
......@@ -48,19 +48,19 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
{
int src_lane = lane_id / 8 + lane_id % 4 * 8;
uint u0 = __shfl_sync(0xffffffff, value, src_lane);
uint u1 = __shfl_sync(0xffffffff, value, src_lane + 4);
// int src_lane = lane_id / 8 + lane_id % 4 * 8;
// uint u0 = __shfl_sync(0xffffffff, value, src_lane);
// uint u1 = __shfl_sync(0xffffffff, value, src_lane + 4);
short2 r;
if (lane_id % 8 < 4) {
r.x = ((short2&)u0).x;
r.y = ((short2&)u1).x;
}
else {
r.x = ((short2&)u0).y;
r.y = ((short2&)u1).y;
}
// if (lane_id % 8 < 4) {
// r.x = ((short2&)u0).x;
// r.y = ((short2&)u1).x;
// }
// else {
// r.x = ((short2&)u0).y;
// r.y = ((short2&)u1).y;
// }
return (uint&)r;
}
......@@ -87,6 +87,7 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
// #else
// return transpose_m8n8_b16_warp_shuffle(a, lane_id);
// #endif
return a;
}
namespace ops {
......@@ -158,61 +159,61 @@ struct Gemm {
int& gemm_iter)
{
constexpr int ITER_M = WARP_M / OP_M;
constexpr int ITER_N = WARP_N / OP_N;
constexpr int ITER_K = WARP_K / OP_K;
constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;
auto frag_C_ptr = (Array<float, 4>*)accum; // [ITER_N, ITER_M]
PRAGMA_UNROLL
for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {
warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
auto warp_frag_A = warp_frag_A_[iter_k % 2];
auto warp_frag_B = warp_frag_B_[iter_k % 2];
PRAGMA_UNROLL
for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
PRAGMA_UNROLL
for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
auto& frag_A = warp_frag_A[iter_m];
auto& frag_B = warp_frag_B[iter_n];
auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
}
}
if (iter_k < ITER_K - 1) {
iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
}
if (iter_k == ITER_K - 2) {
iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id);
iter_A.next_stage();
iter_Q.next_stage();
iter_B.next_stage();
warp_iter_A.next_stage();
warp_iter_B.next_stage();
--gemm_iter;
}
}
// constexpr int ITER_M = WARP_M / OP_M;
// constexpr int ITER_N = WARP_N / OP_N;
// constexpr int ITER_K = WARP_K / OP_K;
// constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
// constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
// constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;
// auto frag_C_ptr = (Array<float, 4>*)accum; // [ITER_N, ITER_M]
// PRAGMA_UNROLL
// for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {
// warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
// warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
// auto warp_frag_A = warp_frag_A_[iter_k % 2];
// auto warp_frag_B = warp_frag_B_[iter_k % 2];
// PRAGMA_UNROLL
// for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
// PRAGMA_UNROLL
// for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
// auto& frag_A = warp_frag_A[iter_m];
// auto& frag_B = warp_frag_B[iter_n];
// auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
// mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
// }
// }
// if (iter_k < ITER_K - 1) {
// iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
// iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
// iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
// }
// if (iter_k == ITER_K - 2) {
// iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
// iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
// iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
// __pipeline_commit();
// __pipeline_wait_prior(STAGES - 2);
// sync_slice(slice_id);
// iter_A.next_stage();
// iter_Q.next_stage();
// iter_B.next_stage();
// warp_iter_A.next_stage();
// warp_iter_B.next_stage();
// --gemm_iter;
// }
// }
}
template<typename T, int N>
......@@ -235,35 +236,35 @@ struct Gemm {
__device__ void sync_slice(int slice_id)
{
if constexpr (SLICES == 1) {
__syncthreads();
}
else {
constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
// asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
}
// if constexpr (SLICES == 1) {
// __syncthreads();
// }
// else {
// constexpr int SLICE_GROUP = (SLICES + 7) / 8;
// constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
// const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
// // asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
// }
}
__device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
{
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) {
tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
}
}
// if (slice_id == 0) {
// PRAGMA_UNROLL
// for (int i = 0; i < CTA_N; ++i) {
// tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
// }
// }
}
__device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
{
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) {
partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
}
}
// if (slice_id == 0) {
// PRAGMA_UNROLL
// for (int i = 0; i < CTA_N; ++i) {
// partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
// }
// }
}
template<int Index>
......@@ -280,80 +281,80 @@ struct Gemm {
int slice_id)
{
if (slice_id != 0) {
return;
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 2; ++x) {
const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
// convert to half
half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// transpose 8x8 accum tile
uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
// store to global memory
OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
}
}
}
// if (slice_id != 0) {
// return;
// }
// // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
// PRAGMA_UNROLL
// for (int i = 0; i < WARP_N / OP_N; ++i) {
// const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
// const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
// PRAGMA_UNROLL
// for (int j = 0; j < WARP_M / OP_M; ++j) {
// PRAGMA_UNROLL
// for (int x = 0; x < 2; ++x) {
// const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
// // convert to half
// half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// // transpose 8x8 accum tile
// uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
// // store to global memory
// OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
// }
// }
// }
}
__device__ void
sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
{
int offset_m = warp_id_m * WARP_M / OP_M;
int offset_n = warp_id_n * WARP_N / OP_N;
PRAGMA_UNROLL
for (int z = 0; z < SLICES; ++z) {
if (slice_id == z) {
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) {
int src = (i * WARP_M / OP_M + j) * 4 + x;
int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
if (z > 0) {
using namespace ops;
tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
}
tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
}
}
}
}
__syncthreads();
}
if (slice_id == 0) {
PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) {
int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
int dst = (i * WARP_M / OP_M + j) * 4 + x;
tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
}
}
}
}
// int offset_m = warp_id_m * WARP_M / OP_M;
// int offset_n = warp_id_n * WARP_N / OP_N;
// PRAGMA_UNROLL
// for (int z = 0; z < SLICES; ++z) {
// if (slice_id == z) {
// PRAGMA_UNROLL
// for (int i = 0; i < WARP_N / OP_N; ++i) {
// PRAGMA_UNROLL
// for (int j = 0; j < WARP_M / OP_M; ++j) {
// PRAGMA_UNROLL
// for (int x = 0; x < 4; ++x) {
// int src = (i * WARP_M / OP_M + j) * 4 + x;
// int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
// if (z > 0) {
// using namespace ops;
// tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
// }
// tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
// }
// }
// }
// }
// __syncthreads();
// }
// if (slice_id == 0) {
// PRAGMA_UNROLL
// for (int i = 0; i < WARP_N / OP_N; ++i) {
// PRAGMA_UNROLL
// for (int j = 0; j < WARP_M / OP_M; ++j) {
// PRAGMA_UNROLL
// for (int x = 0; x < 4; ++x) {
// int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
// int dst = (i * WARP_M / OP_M + j) * 4 + x;
// tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
// }
// }
// }
// }
}
Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
// Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
// Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
__device__ void run_v2(half* __restrict__ C,
const uint* __restrict__ A,
......@@ -364,89 +365,89 @@ struct Gemm {
int K,
int output_op_idx)
{
static_assert(WARP_M % OP_N == 0);
// static_assert(WARP_M % OP_N == 0);
float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];
// float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];
extern __shared__ uint8_t smem[];
// extern __shared__ uint8_t smem[];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
// const int warp_id = threadIdx.x / WARP_SIZE;
// const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id_m = warp_id % kWarpCountM;
const int warp_id_nk = warp_id / kWarpCountM;
const int warp_id_n = warp_id_nk % kWarpCountN;
const int warp_id_k = warp_id_nk / kWarpCountN;
// const int warp_id_m = warp_id % kWarpCountM;
// const int warp_id_nk = warp_id / kWarpCountM;
// const int warp_id_n = warp_id_nk % kWarpCountN;
// const int warp_id_k = warp_id_nk / kWarpCountN;
const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;
// const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;
const int slice_id = warp_id_k;
// const int slice_id = warp_id_k;
const int cta_k = slice_id * SLICE_K; // sliced-k offset
const int cta_m = blockIdx.x * CTA_M;
const int cta_n = blockIdx.y * CTA_N;
// const int cta_k = slice_id * SLICE_K; // sliced-k offset
// const int cta_m = blockIdx.x * CTA_M;
// const int cta_n = blockIdx.y * CTA_N;
// each slice has its own partition of smem
uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);
// // each slice has its own partition of smem
// uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
// half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);
// [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
float* const tb_smem_C = (float*)smem;
// // [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
// float* const tb_smem_C = (float*)smem;
__shared__ typename IteratorQ::Storage tb_smem_Q_storage;
// __shared__ typename IteratorQ::Storage tb_smem_Q_storage;
auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];
// auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];
IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};
// IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
// IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
// IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};
const int offset_m = warp_id_m * WARP_M + lane_id;
// const int offset_m = warp_id_m * WARP_M + lane_id;
WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);
// WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
// WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);
int gemm_iter = (K + CTA_K - 1) / CTA_K;
// int gemm_iter = (K + CTA_K - 1) / CTA_K;
PRAGMA_UNROLL
for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
iter_A.prefetch_stage(gemm_iter > 0);
iter_Q.prefetch_stage(gemm_iter > 0);
iter_B.prefetch_stage(gemm_iter > 0);
__pipeline_commit();
}
// PRAGMA_UNROLL
// for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
// iter_A.prefetch_stage(gemm_iter > 0);
// iter_Q.prefetch_stage(gemm_iter > 0);
// iter_B.prefetch_stage(gemm_iter > 0);
// __pipeline_commit();
// }
clear(tb_frag_C);
// clear(tb_frag_C);
__pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id);
// __pipeline_wait_prior(STAGES - 2);
// sync_slice(slice_id);
warp_iter_A.load(warp_frag_A_[0], 0);
warp_iter_B.load(warp_frag_B_[0], 0);
// warp_iter_A.load(warp_frag_A_[0], 0);
// warp_iter_B.load(warp_frag_B_[0], 0);
PRAGMA_NO_UNROLL
for (; gemm_iter > -STAGES + 1;) {
warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
}
// PRAGMA_NO_UNROLL
// for (; gemm_iter > -STAGES + 1;) {
// warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
// }
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
// __pipeline_commit();
// __pipeline_wait_prior(0);
// __syncthreads();
if constexpr (SLICES > 1) {
sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
}
// if constexpr (SLICES > 1) {
// sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
// }
switch (output_op_idx) {
case 0:
store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break;
case 1:
store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break;
default:
return;
}
// switch (output_op_idx) {
// case 0:
// store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
// break;
// case 1:
// store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
// break;
// default:
// return;
// }
}
};
......
......@@ -78,6 +78,7 @@ bool BlockManager::Malloc()
return false;
}
//auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
auto ptr = (uint8_t*)allocator_->malloc(block_size_ * chunk_size);
if (!ptr) {
return false;
......
......@@ -19,13 +19,15 @@ add_library(Llama STATIC
unified_attention_layer.cc
llama_kernels.cu
llama_decoder_kernels.cu
llama_utils.cu)
llama_utils.cu
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
#set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_directories(Llama PUBLIC ../../../../3rdparty/composable_kernel/)
target_link_libraries(Llama PUBLIC cudart
# gemm_s4_f16
gemm_s4_f16
cublasMMWrapper
DynamicDecodeLayer
activation_kernels
......@@ -41,7 +43,8 @@ target_link_libraries(Llama PUBLIC cudart
memory_utils
nccl_utils
cuda_utils
logger)
logger
gemm_multiB_int4)
# llama_fmha)
if (NOT MSVC)
......
......@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t inter_size,
WeightType weight_type,
int group_size,
int w4_weight_layout,
bool attn_bias,
size_t tensor_para_size,
size_t tensor_para_rank):
......@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
self_attn_weights.qkv.type = weight_type;
self_attn_weights.qkv.group_size = group_size;
self_attn_weights.qkv.w4_weight_layout = w4_weight_layout;
self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
self_attn_weights.output.output_dims = hidden_units_;
self_attn_weights.output.type = weight_type;
self_attn_weights.output.group_size = group_size;
self_attn_weights.output.w4_weight_layout = w4_weight_layout;
ffn_weights.gating.input_dims = hidden_units_;
ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.gating.type = weight_type;
ffn_weights.gating.group_size = group_size;
ffn_weights.gating.w4_weight_layout = w4_weight_layout;
ffn_weights.intermediate.input_dims = hidden_units_;
ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.intermediate.type = weight_type;
ffn_weights.intermediate.group_size = group_size;
ffn_weights.intermediate.w4_weight_layout = w4_weight_layout;
ffn_weights.fused_gating_intermediate.input_dims = hidden_units_;
ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
ffn_weights.fused_gating_intermediate.type = weight_type;
ffn_weights.fused_gating_intermediate.group_size = group_size;
ffn_weights.fused_gating_intermediate.w4_weight_layout = w4_weight_layout;
ffn_weights.output.input_dims = inter_size_ / tensor_para_size_;
ffn_weights.output.output_dims = hidden_units_;
ffn_weights.output.type = weight_type;
ffn_weights.output.group_size = group_size;
ffn_weights.output.w4_weight_layout = w4_weight_layout;
mallocWeights();
}
......@@ -111,10 +118,28 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
FT_CHECK(weights.input_dims % factor == 0);
deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
// //读环境变量
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((weights.input_dims%4096==0)&&(weights.w4_weight_layout==1||weights.w4_weight_layout==2))
{
size_t new_input_dims=weights.input_dims+2*weights.group_size;
deviceMalloc((int**)&weights.kernel, new_input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, new_input_dims* weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, new_input_dims / weights.group_size * weights.output_dims * 2);
}
else{
deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
}
}
}
......@@ -146,16 +171,39 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
}
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
output.insert(get_name("qweight"),
Tensor{MEMORY_GPU,
TYPE_INT32,
{weights.input_dims * weights.output_dims * sizeof(int) / factor},
weights.kernel});
output.insert(get_name("scales_zeros"),
Tensor{MEMORY_GPU,
getTensorType<T>(),
{weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros});
// //读环境变量
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((weights.input_dims%4096==0)&&(weights.w4_weight_layout==1||weights.w4_weight_layout==2))
{
size_t new_input_dims=weights.input_dims+weights.group_size;
output.insert(get_name("qweight"),
Tensor{MEMORY_GPU,
TYPE_INT32,
{new_input_dims * weights.output_dims * sizeof(int) / factor},
weights.kernel});
output.insert(get_name("scales_zeros"),
Tensor{MEMORY_GPU,
getTensorType<T>(),
{new_input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros});
}
else{
output.insert(get_name("qweight"),
Tensor{MEMORY_GPU,
TYPE_INT32,
{weights.input_dims * weights.output_dims * sizeof(int) / factor},
weights.kernel});
output.insert(get_name("scales_zeros"),
Tensor{MEMORY_GPU,
getTensorType<T>(),
{weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros});
}
}
}
......@@ -259,12 +307,31 @@ void loadWeights(LlamaDenseWeight<T>& w,
FT_CHECK(dim1 % factor == 0);
std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
// //读环境变量
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((dim0%4096==0)&&(w.w4_weight_layout==1||w.w4_weight_layout==2))
{
size_t new_dim0=dim0+2*w.group_size;
std::vector<size_t> w_shape{new_dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
const size_t group_count = w.group_size > 0 ? new_dim0 / w.group_size : 1;
const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
}
else{
std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment