Commit d26f4c73 authored by gaoqiong

Add AWQ module

parent 2326380c
...@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED ...@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED
# $<TARGET_OBJECTS:flash_attention2> # $<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama> $<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend> $<TARGET_OBJECTS:LlamaTritonBackend>
# $<TARGET_OBJECTS:gemm_s4_f16> $<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer> $<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer> $<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend> $<TARGET_OBJECTS:TransformerTritonBackend>
......
...@@ -61,7 +61,11 @@ class CLI(object): ...@@ -61,7 +61,11 @@ class CLI(object):
default=0, default=0,
help='A parameter used in awq to quantize fp16 weights ' help='A parameter used in awq to quantize fp16 weights '
'to 4 bits') 'to 4 bits')
parser.add_argument(
'--w4-weight-layout',
type=int,
default=2,
help='A parameter used in AWQ to control the layout of the weight')
parser.set_defaults(run=CLI.convert) parser.set_defaults(run=CLI.convert)
@staticmethod @staticmethod
......
...@@ -196,6 +196,7 @@ def main(model_name: str, ...@@ -196,6 +196,7 @@ def main(model_name: str,
tp: int = 1, tp: int = 1,
quant_path: str = None, quant_path: str = None,
group_size: int = 0, group_size: int = 0,
w4_weight_layout: int = 2,
**kwargs): **kwargs):
"""deploy llama family models via turbomind. """deploy llama family models via turbomind.
...@@ -215,6 +216,7 @@ def main(model_name: str, ...@@ -215,6 +216,7 @@ def main(model_name: str,
quant_path (str): Path of the quantized model, which can be None. quant_path (str): Path of the quantized model, which can be None.
group_size (int): a parameter used in AWQ to quantize fp16 weights group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits to 4 bits
w4_weight_layout (int): a parameter used in AWQ to control the layout of the weight
kwargs (dict): other params for convert kwargs (dict): other params for convert
""" """
...@@ -260,10 +262,13 @@ def main(model_name: str, ...@@ -260,10 +262,13 @@ def main(model_name: str,
cfg.tensor_para_size = tp cfg.tensor_para_size = tp
cfg.rotary_embedding = cfg.size_per_head cfg.rotary_embedding = cfg.size_per_head
cfg.group_size = group_size cfg.group_size = group_size
cfg.w4_weight_layout=w4_weight_layout
if inferred_model_format.find('awq') != -1: if inferred_model_format.find('awq') != -1:
cfg.weight_type = 'int4' cfg.weight_type = 'int4'
output_format = 'w4' output_format = 'w4'
assert group_size > 0, f'group_size: {group_size} should > 0' assert group_size > 0, f'group_size: {group_size} should > 0'
print("w4_weight_layout:",w4_weight_layout)
assert w4_weight_layout >= 0 and w4_weight_layout < 3, f'w4_weight_layout: {w4_weight_layout} should be >= 0 and < 3'
else: else:
#output_format = update_output_format(model_name, inferred_model_format, #output_format = update_output_format(model_name, inferred_model_format,
# model_path, output_format) # model_path, output_format)
......
...@@ -5,6 +5,7 @@ import inspect ...@@ -5,6 +5,7 @@ import inspect
import io import io
import json import json
import os.path as osp import os.path as osp
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from configparser import ConfigParser from configparser import ConfigParser
...@@ -52,6 +53,7 @@ class TurbomindModelConfig: ...@@ -52,6 +53,7 @@ class TurbomindModelConfig:
rope_theta: float = 10000.0 rope_theta: float = 10000.0
size_per_head: int = 128 size_per_head: int = 128
group_size: int = 0 group_size: int = 0
w4_weight_layout: int = 2
max_batch_size: int = 64 max_batch_size: int = 64
max_context_token_num: int = 1 max_context_token_num: int = 1
step_length: int = 1 step_length: int = 1
...@@ -150,6 +152,12 @@ class BaseOutputModel(ABC): ...@@ -150,6 +152,12 @@ class BaseOutputModel(ABC):
self.to_file = to_file self.to_file = to_file
self.out_dir = out_dir self.out_dir = out_dir
self.tm_params = {} self.tm_params = {}
#self.weight_layout= 1
# read the environment variable
#env_weight_layout = os.environ.get('LMDEPLOY_WEIGHTLAYOUT_SWITCH', '1')
#self.weight_layout =int(env_weight_layout)
#print("self.weight_layout:",self.weight_layout)
@abstractmethod @abstractmethod
def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig: def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
...@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128): ...@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128):
return x.view(n_heads, 2, dim // n_heads // 2, return x.view(n_heads, 2, dim // n_heads // 2,
1).transpose(1, 2).reshape(dim, 1) 1).transpose(1, 2).reshape(dim, 1)
def permute_trans(x: torch.Tensor):
if x.shape[-1]>1:
dim = x.shape[-1]
return x.view(-1, x.shape[-1]).transpose(0, 1).reshape(dim,-1)
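
The new permute_trans helper transposes the interleaved scales/zeros from a per-group row layout to a per-output-column layout. A minimal sketch of its effect on a small tensor (shapes are illustrative, not taken from a real checkpoint):

import torch

# e.g. (num_groups, output_dims) scales/zeros produced by convert_s4_
x = torch.arange(6).view(2, 3)
dim = x.shape[-1]
y = x.view(-1, x.shape[-1]).transpose(0, 1).reshape(dim, -1)
print(x.shape, y.shape)  # torch.Size([2, 3]) torch.Size([3, 2])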
def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int, def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
dim: int): dim: int):
......
...@@ -8,7 +8,7 @@ import lmdeploy ...@@ -8,7 +8,7 @@ import lmdeploy
from ..source_model.base import BaseInputModel, BaseReader from ..source_model.base import BaseInputModel, BaseReader
from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig, from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig,
merge_qkv, permute) merge_qkv, permute,permute_trans)
# import _turbomind as _tm # import _turbomind as _tm
# TODO: find another way import _turbomind # TODO: find another way import _turbomind
...@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor, ...@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
qw.size(-1) * 8, qw.size(0), group_size) qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz return _qw, _sz
def convert_s4_(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
assert qw.is_contiguous()
assert qz.is_contiguous()
assert s.is_contiguous()
_qw = torch.zeros_like(qw)
_sz = torch.zeros_like(s, dtype=torch.int32) # half2
_ws = torch.zeros_like(s)
_tm.convert_s4_k_m8_(_qw, _sz, _ws, qw, s, qz,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz
def tp_m_s4(x: torch.Tensor, tp: int): def tp_m_s4(x: torch.Tensor, tp: int):
return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3, return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
...@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel): ...@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel):
"""Export transformer layer i.""" """Export transformer layer i."""
group_size = self.cfg.group_size group_size = self.cfg.group_size
tp = self.cfg.tensor_para_size tp = self.cfg.tensor_para_size
w4_weight_layout = self.cfg.w4_weight_layout
size_per_head = self.cfg.size_per_head size_per_head = self.cfg.size_per_head
# attn # attn
q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i)) q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i))
...@@ -121,12 +134,45 @@ class TurbomindW4Model(BaseOutputModel): ...@@ -121,12 +134,45 @@ class TurbomindW4Model(BaseOutputModel):
qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2) qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2) qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size) pad_group_count=2
qkv_qw = tp_m_s4(qkv_qw, tp)
if w4_weight_layout==1 or w4_weight_layout==2:
if qkv_qw.shape[0]%4096==0:
qkv_qw_padding=torch.zeros(group_size*pad_group_count,qkv_qw.shape[1],dtype=torch.int32).cuda()
qkv_qw =torch.cat((qkv_qw,qkv_qw_padding),dim=0).contiguous()
qkv_qz_padding =torch.zeros(pad_group_count,qkv_qz.shape[1],dtype=torch.int32).cuda()
qkv_qz =torch.cat((qkv_qz,qkv_qz_padding),dim=0).contiguous()
qkv_s_padding =torch.zeros(pad_group_count,qkv_s.shape[1],dtype=torch.float16).cuda()
qkv_s =torch.cat((qkv_s,qkv_s_padding),dim=0).contiguous()
qkv_qw, qkv_sz = convert_s4_(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp)
qkv_sz = permute_trans(qkv_sz)
else:
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp)
#print("请设置weight layout\n")
self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1) self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1)
self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1) self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1)
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size) if w4_weight_layout==1 or w4_weight_layout==2:
if o_qw.shape[0]%4096==0:
o_qw_padding=torch.zeros(group_size*pad_group_count,o_qw.shape[1],dtype=torch.int32).cuda()
o_qw =torch.cat((o_qw,o_qw_padding),dim=0).contiguous()
o_qz_padding =torch.zeros(pad_group_count,o_qz.shape[1],dtype=torch.int32).cuda()
o_qz =torch.cat((o_qz,o_qz_padding),dim=0).contiguous()
o_s_padding =torch.zeros(pad_group_count,o_s.shape[1],dtype=torch.float16).cuda()
o_s =torch.cat((o_s,o_s_padding),dim=0).contiguous()
o_qw, o_sz = convert_s4_(o_qw, o_qz, o_s, group_size)
o_sz = permute_trans(o_sz)
else:
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0) self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0)
self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0) self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0)
...@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel): ...@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel):
w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz, w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
w3_s) w3_s)
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size) if w4_weight_layout==1 or w4_weight_layout==2:
w13_qw = tp_m_s4(w13_qw, tp) if w13_qw.shape[0]%4096==0:
w13_qw_padding=torch.zeros(group_size*pad_group_count,w13_qw.shape[1],dtype=torch.int32).cuda()
w13_qw =torch.cat((w13_qw,w13_qw_padding),dim=0).contiguous()
w13_qz_padding =torch.zeros(pad_group_count,w13_qz.shape[1],dtype=torch.int32).cuda()
w13_qz =torch.cat((w13_qz,w13_qz_padding),dim=0).contiguous()
w13_s_padding =torch.zeros(pad_group_count,w13_s.shape[1],dtype=torch.float16).cuda()
w13_s =torch.cat((w13_s,w13_s_padding),dim=0).contiguous()
w13_qw, w13_sz = convert_s4_(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp)
w13_sz = permute_trans(w13_sz)
else:
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp)
self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1) self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1)
self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros', self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros',
-1) -1)
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size) if w4_weight_layout==1 or w4_weight_layout==2:
#padding
if w2_qw.shape[0]%4096==0:
w2_qw_padding=torch.zeros(group_size*pad_group_count,w2_qw.shape[1],dtype=torch.int32).cuda()
w2_qw =torch.cat((w2_qw,w2_qw_padding),dim=0).contiguous()
w2_qz_padding =torch.zeros(pad_group_count,w2_qz.shape[1],dtype=torch.int32).cuda()
w2_qz =torch.cat((w2_qz,w2_qz_padding),dim=0).contiguous()
w2_s_padding =torch.zeros(pad_group_count,w2_s.shape[1],dtype=torch.float16).cuda()
w2_s =torch.cat((w2_s,w2_s_padding),dim=0).contiguous()
w2_qw, w2_sz = convert_s4_(w2_qw, w2_qz, w2_s, group_size)
w2_sz = permute_trans(w2_sz)
else:
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0) self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0)
self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0) self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0)
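
For weight layouts 1 and 2, the converter pads every AWQ tensor whose reduction dimension is a multiple of 4096 with pad_group_count extra zeroed quantization groups before calling convert_s4_. A minimal CPU sketch of the shape bookkeeping, assuming k = 4096 and group_size = 128 (both illustrative):

import torch

k, n, group_size, pad_group_count = 4096, 512, 128, 2      # illustrative sizes

qw = torch.zeros(k, n, dtype=torch.int32)                   # packed int4 weights
qz = torch.zeros(k // group_size, n, dtype=torch.int32)     # packed zeros, one row per group
s = torch.zeros(k // group_size, n, dtype=torch.float16)    # scales, one row per group

if qw.shape[0] % 4096 == 0:
    qw = torch.cat([qw, torch.zeros(group_size * pad_group_count, n, dtype=qw.dtype)], dim=0)
    qz = torch.cat([qz, torch.zeros(pad_group_count, n, dtype=qz.dtype)], dim=0)
    s = torch.cat([s, torch.zeros(pad_group_count, n, dtype=s.dtype)], dim=0)

print(qw.shape[0], qz.shape[0], s.shape[0])  # 4352 34 34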
......
...@@ -147,6 +147,7 @@ class TurboMind: ...@@ -147,6 +147,7 @@ class TurboMind:
model_name: Optional[str] = None, model_name: Optional[str] = None,
model_format: Optional[str] = None, model_format: Optional[str] = None,
group_size: Optional[int] = None, group_size: Optional[int] = None,
w4_weight_layout: Optional[int] = None,
tp: Optional[int] = None, tp: Optional[int] = None,
chat_template_config: Optional[ChatTemplateConfig] = None, chat_template_config: Optional[ChatTemplateConfig] = None,
**kwargs): **kwargs):
...@@ -179,6 +180,7 @@ class TurboMind: ...@@ -179,6 +180,7 @@ class TurboMind:
engine_config = _update_engine_config(engine_config, engine_config = _update_engine_config(engine_config,
model_format=model_format, model_format=model_format,
group_size=group_size, group_size=group_size,
w4_weight_layout=w4_weight_layout,
tp=tp, tp=tp,
**kwargs) **kwargs)
...@@ -304,6 +306,7 @@ class TurboMind: ...@@ -304,6 +306,7 @@ class TurboMind:
output_format = 'w4' output_format = 'w4'
data_type = 'int4' data_type = 'int4'
cfg.group_size = 128 cfg.group_size = 128
cfg.w4_weight_layout=2
else: else:
# output_format = update_output_format(cfg.model_name, # output_format = update_output_format(cfg.model_name,
# inferred_model_format, # inferred_model_format,
...@@ -378,6 +381,7 @@ class TurboMind: ...@@ -378,6 +381,7 @@ class TurboMind:
self.config = cfg self.config = cfg
self.model_name = cfg.model_name self.model_name = cfg.model_name
self.data_type = cfg.weight_type self.data_type = cfg.weight_type
print("from_workspace_cfg:",cfg)
# create model # create model
logger.warning(f'model_config:\n\n{cfg.toini()}') logger.warning(f'model_config:\n\n{cfg.toini()}')
......
...@@ -69,8 +69,20 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -69,8 +69,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
file.writelines(lines) file.writelines(lines)
file.close() file.close()
def copy_ck_so():
lmdeploy_root = os.path.dirname(os.path.abspath(__file__))
so_path = os.path.join(lmdeploy_root, "3rdparty", "libgemm_multiB_int4.so")
# dtk version
if os.getenv("ROCM_PATH"):
rocm_path = os.getenv('ROCM_PATH', "")
rocm_so_path = os.path.join(rocm_path, 'lib')
print("rocm_so_path:",rocm_so_path)
shutil.copy(so_path, rocm_so_path)
else:
shutil.copy(so_path, "usr/local/lib")
def get_version(): def get_version():
copy_ck_so()
get_version_add() get_version_add()
version_file = 'lmdeploy/version.py' version_file = 'lmdeploy/version.py'
with open(version_file, encoding='utf-8') as f: with open(version_file, encoding='utf-8') as f:
......
...@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) ...@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
#set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#add_subdirectory(gemm_s_f16) add_subdirectory(gemm_s_f16)
add_subdirectory(decoder_multihead_attention) add_subdirectory(decoder_multihead_attention)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu) add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu ../../models/llama/awq_sugon/gemm_w4_dequation.cu)
target_compile_options(gemm_s4_f16 PRIVATE target_compile_options(gemm_s4_f16 PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr) --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
set_property(TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET gemm_s4_f16 PROPERTY POSITION_INDEPENDENT_CODE ON)
......
...@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) ...@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[0]) // : "=r"(h[0])
// : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 h[0]=(i4s & BOTTOM_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[1]) // : "=r"(h[1])
// : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 h[1]=(i4s & TOP_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[2]) // : "=r"(h[2])
// : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 h[2]=(top_i4s & BOTTOM_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[3]) // : "=r"(h[3])
// : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
printf("=========common.h 86\n"); h[3]=(top_i4s & TOP_MASK)|I4s_TO_F16s_MAGIC_NUM;
// I use inline PTX below because I am not sure if the compiler will emit // I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose // float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability. // performance reliability over code readability.
...@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) ...@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers. // Finally, we construct the output numbers.
// Convert elt_01 // Convert elt_01
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); //asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// // Convert elt_23 h[0]=h[0]-FP16_TOP_MAGIC_NUM;
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); // Convert elt_23
// // Convert elt_45 //asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); h[1]=h[1]*ONE_SIXTEENTH+NEG_64;
// // Convert elt_67 // Convert elt_45
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); //asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
h[2]=h[2]-FP16_TOP_MAGIC_NUM;
// Convert elt_67
//asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h[3]=h[3]*ONE_SIXTEENTH+NEG_64;
return result; return result;
} }
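
The replaced inline PTX relied on the usual fp16 magic-number trick: OR-ing a 4-bit value into the mantissa of 1024.0 (0x6400) yields 1024 + v for a bottom nibble and 1024 + 16*v for a top nibble, so a subtraction (or an fma with 1/16 and -64) recovers the integer. A small numpy check of that identity, assuming the conventional constant values behind BOTTOM_MASK/TOP_MASK/I4s_TO_F16s_MAGIC_NUM (the script is only illustrative):

import numpy as np

def bits_to_fp16(u):
    # reinterpret a 16-bit pattern as an IEEE fp16 value
    return np.array([u], dtype=np.uint16).view(np.float16)[0]

v = 0xA                                    # one 4-bit weight value
lo = bits_to_fp16(0x6400 | v)              # (i4s & 0x000f) | 0x6400
hi = bits_to_fp16(0x6400 | (v << 4))       # (i4s & 0x00f0) | 0x6400

print(float(lo) - 1024.0)                  # 10.0 -> h = h - FP16_TOP_MAGIC_NUM
print(float(hi) * (1.0 / 16.0) - 64.0)     # 10.0 -> h = h * ONE_SIXTEENTH + NEG_64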
...@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source) ...@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required. // dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8; const uint32_t top_i4s = i4s >> 8;
printf("=========common.h 133\n");
// if (0) { // 1024 & 64 // 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut)); h[0] =(i4s & BOT_MASK) |MAGIC_NUM_2;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut)); h[1] =(i4s & TOP_MASK) |MAGIC_NUM_1;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut)); h[2] =(top_i4s & BOT_MASK) |MAGIC_NUM_2;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut)); h[3] =(top_i4s & TOP_MASK) |MAGIC_NUM_1;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0)); h[0] <<= 4;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1)); h[2] <<= 4;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0)); // we don't need to subtract the magic nums because zeros will go through the same dequant function
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1)); // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
// else { // 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// h[0] <<= 4;
// h[2] <<= 4;
// // we don't need to subtract the magic nums because zeros will go through the same dequant function
// // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
return result; return result;
} }
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{ {
uint32_t smem_int_ptr; uint32_t smem_int_ptr;
...@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme ...@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q) __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{ {
uint s, z; //uint s, z;
(half2&)z = __halves2half2(q.x, q.x); //(half2&)z = __halves2half2(q.x, q.x);
(half2&)s = __halves2half2(q.y, q.y); //(half2&)s = __halves2half2(q.y, q.y);
auto& t = (const uint&)x; //auto& t = (const uint&)x;
uint u, v; uint v;
// if (TURBOMIND_S4_DEQUANT_USE_FMA) { // if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z)); // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
// } // }
...@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q) ...@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
// asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z)); // asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
// asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s)); // asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
// } // }
printf("=========common.h 235\n");
return (half2&)v; return (half2&)v;
} }
......
// Copyright (c) OpenMMLab. All rights reserved. // Copyright (c) OpenMMLab. All rights reserved.
#include "common.h" #include "common.h"
#include "src/turbomind/models/llama/awq_sugon/gemm_w4_dequation.cuh"
#include <iostream> #include <iostream>
namespace turbomind { namespace turbomind {
...@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre ...@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
// permutation for [k, m/8] layout // permutation for [k, m/8] layout
Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4}; Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
// |warp| lane | 2x2 | a0-7 | // |warp| lane | 2x2 | a0-7 |
permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape); //permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4<0, 1, 2, 3, 4, 5, 6, 7, 8, 9><<<512, 512, 0, st>>>(dst, src, shape);
}
void reformat_s4_k_m8_tarnsw4(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k, m/8] layout
Array<int, 10> shape{1, k / 8, 2, 2, 2, 1, m / 8, 2, 2, 2};
// 0123456-->4,6,7,5,0,3,1,2
//permute_u4<4, 6, 7, 5, 0, 3, 1, 2><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4<5, 6, 8, 9, 7, 0, 1, 4, 2, 3><<<512, 512, 0, st>>>(dst, src, shape);
} }
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count) __global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
...@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst, ...@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst,
reformat_s4_k_m8(A_dst, A_src, m, k, st); reformat_s4_k_m8(A_dst, A_src, m, k, st);
} }
void convert_s4_k_m8_(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st)
{
dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
reformat_s4_k_m8_tarnsw4(A_dst, A_src, m, k, st);
}
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st) void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{ {
Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2}; Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
...@@ -140,5 +167,82 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s ...@@ -140,5 +167,82 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s
{ {
dequantize_s4_kernel<<<512, 512>>>(dst, src, count); dequantize_s4_kernel<<<512, 512>>>(dst, src, count);
} }
__global__ void dequant_kernel(int num_kernels,half* weight ,const half2* zeros_and_scales,int k,int n,int group_size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id%n;
int i=id/n;
half x=zeros_and_scales[i/group_size*n+j].data[0];
half y= zeros_and_scales[i/group_size*n+j].data[1];
float tmp=(weight[id]-x)*y;
weight[id]=__float2half(tmp);
}
__global__ void dequant_kernel_colmajor(int num_kernels,half* weight ,const half2* zeros_and_scales,int k,int n,int group_size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id/group_size;
half x=zeros_and_scales[j].data[0];
half y= zeros_and_scales[j].data[1];
float tmp=(weight[id]-x)*y;
weight[id]=__float2half(tmp);
}
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size)
{
dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
int num_kernels=k*n;
dequant_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,zeros_and_scales,k,n,group_size);
}
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size)
{
dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
int num_kernels=k*n;
dequant_kernel_colmajor<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,zeros_and_scales,k,n,group_size);
}
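
The new dequant_w4_gemm path first expands the packed int4 weights (dequantize_s4_offset_64) and then applies the group-wise zero/scale pairs. A scalar reference of what dequant_kernel computes per element of a row-major K x N weight; this is a CPU sketch of the indexing, not the CUDA kernel itself:

import numpy as np

def dequant_rowmajor(weight, zeros_and_scales, k, n, group_size):
    # weight: (k*n,) raw fp16 values after int4 expansion
    # zeros_and_scales: (k // group_size * n, 2) interleaved (zero, scale) pairs
    out = np.empty_like(weight)
    for idx in range(k * n):
        j = idx % n                        # output column
        i = idx // n                       # reduction row
        z, s = zeros_and_scales[(i // group_size) * n + j]
        out[idx] = (weight[idx] - z) * s
    return out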
__global__ void FusedSiluActivation_kernel(int num_kernels,half* output ,const uint32_t* src,int m,int n)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
auto data = ((half2*)src)[id];
float x= __half2float(data.data[0]);
float y= __half2float(data.data[1]);
float silu=x / (1.f + __expf(-x))*y;
output[id]=__float2half(silu);
}
__global__ void assign_kernel(int num_kernels,half* output ,const half* src,int m,int n)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
output[id]=src[id];
}
void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type)
{
int num_kernels=m*n;
switch (type) {
case 0:
assign_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,src,m,n);
break;
case 1:
FusedSiluActivation_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(int(num_kernels/2),output,(const uint32_t*)src,m,n);
break;
default:
return;
}
}
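
addFusedSiluActivation with type == 1 reads the fused w1/w3 GEMM output as half2 pairs and writes silu(gate) * up, halving the column count. A compact numpy sketch of that epilogue, assuming the gate/up values are interleaved element-wise as in the half2 loads (illustrative only):

import numpy as np

def fused_silu(src):
    # src: (m, n) fp16 where adjacent elements hold interleaved (gate, up) pairs
    x = src[:, 0::2].astype(np.float32)    # gate half of each half2
    y = src[:, 1::2].astype(np.float32)    # up half of each half2
    return (x / (1.0 + np.exp(-x)) * y).astype(np.float16)   # shape (m, n // 2)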
} // namespace turbomind } // namespace turbomind
...@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst, ...@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst,
int group_size, int group_size,
cudaStream_t st = {}); cudaStream_t st = {});
void convert_s4_k_m8_(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st = {});
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {}); void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {}); void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
namespace turbomind { namespace turbomind {
extern bool g_dump_kernel_info_once; extern bool g_dump_kernel_info_once;
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type);
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
class GemmS4F16 { class GemmS4F16 {
public: public:
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "common.h" #include "common.h"
#include "cta_iterator.h" #include "cta_iterator.h"
#include "warp_iterator.h" #include "warp_iterator.h"
#include <cuda_pipeline_primitives.h> //#include <cuda_pipeline_primitives.h>
namespace turbomind { namespace turbomind {
...@@ -48,19 +48,19 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha ...@@ -48,19 +48,19 @@ mma_m16n8k16_row_col(Array<float, 4>& d, const Array<half, 8>& a, const Array<ha
__inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id) __inline__ __device__ uint transpose_m8n8_b16_warp_shuffle(uint value, int lane_id)
{ {
int src_lane = lane_id / 8 + lane_id % 4 * 8; // int src_lane = lane_id / 8 + lane_id % 4 * 8;
uint u0 = __shfl_sync(0xffffffff, value, src_lane); // uint u0 = __shfl_sync(0xffffffff, value, src_lane);
uint u1 = __shfl_sync(0xffffffff, value, src_lane + 4); // uint u1 = __shfl_sync(0xffffffff, value, src_lane + 4);
short2 r; short2 r;
if (lane_id % 8 < 4) { // if (lane_id % 8 < 4) {
r.x = ((short2&)u0).x; // r.x = ((short2&)u0).x;
r.y = ((short2&)u1).x; // r.y = ((short2&)u1).x;
} // }
else { // else {
r.x = ((short2&)u0).y; // r.x = ((short2&)u0).y;
r.y = ((short2&)u1).y; // r.y = ((short2&)u1).y;
} // }
return (uint&)r; return (uint&)r;
} }
...@@ -87,6 +87,7 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id) ...@@ -87,6 +87,7 @@ __inline__ __device__ uint transpose_m8n8_b16(uint a, int lane_id)
// #else // #else
// return transpose_m8n8_b16_warp_shuffle(a, lane_id); // return transpose_m8n8_b16_warp_shuffle(a, lane_id);
// #endif // #endif
return a;
} }
namespace ops { namespace ops {
...@@ -158,61 +159,61 @@ struct Gemm { ...@@ -158,61 +159,61 @@ struct Gemm {
int& gemm_iter) int& gemm_iter)
{ {
constexpr int ITER_M = WARP_M / OP_M; // constexpr int ITER_M = WARP_M / OP_M;
constexpr int ITER_N = WARP_N / OP_N; // constexpr int ITER_N = WARP_N / OP_N;
constexpr int ITER_K = WARP_K / OP_K; // constexpr int ITER_K = WARP_K / OP_K;
constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K; // constexpr int kBatchA = (IteratorA::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K; // constexpr int kBatchQ = (IteratorQ::kIterCount + ITER_K - 1) / ITER_K;
constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K; // constexpr int kBatchB = (IteratorB::kIterCount + ITER_K - 1) / ITER_K;
auto frag_C_ptr = (Array<float, 4>*)accum; // [ITER_N, ITER_M] // auto frag_C_ptr = (Array<float, 4>*)accum; // [ITER_N, ITER_M]
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int iter_k = 0; iter_k < ITER_K; ++iter_k) { // for (int iter_k = 0; iter_k < ITER_K; ++iter_k) {
warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K); // warp_iter_A.load(warp_frag_A_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K); // warp_iter_B.load(warp_frag_B_[(iter_k + 1) % 2], (iter_k + 1) % ITER_K);
auto warp_frag_A = warp_frag_A_[iter_k % 2]; // auto warp_frag_A = warp_frag_A_[iter_k % 2];
auto warp_frag_B = warp_frag_B_[iter_k % 2]; // auto warp_frag_B = warp_frag_B_[iter_k % 2];
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int iter_m = 0; iter_m < ITER_M; ++iter_m) { // for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int iter_n = 0; iter_n < ITER_N; ++iter_n) { // for (int iter_n = 0; iter_n < ITER_N; ++iter_n) {
auto& frag_A = warp_frag_A[iter_m]; // auto& frag_A = warp_frag_A[iter_m];
auto& frag_B = warp_frag_B[iter_n]; // auto& frag_B = warp_frag_B[iter_n];
auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m]; // auto& frag_C = frag_C_ptr[iter_n * ITER_M + iter_m];
mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C); // mma_m16n8k16_row_col(frag_C, frag_A, frag_B, frag_C);
} // }
} // }
if (iter_k < ITER_K - 1) { // if (iter_k < ITER_K - 1) {
iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0); // iter_A.prefetch_batch(iter_k, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0); // iter_Q.prefetch_batch(iter_k, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0); // iter_B.prefetch_batch(iter_k, kBatchB, gemm_iter > 0);
} // }
if (iter_k == ITER_K - 2) { // if (iter_k == ITER_K - 2) {
iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0); // iter_A.prefetch_batch(iter_k + 1, kBatchA, gemm_iter > 0);
iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0); // iter_Q.prefetch_batch(iter_k + 1, kBatchQ, gemm_iter > 0);
iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0); // iter_B.prefetch_batch(iter_k + 1, kBatchB, gemm_iter > 0);
__pipeline_commit(); // __pipeline_commit();
__pipeline_wait_prior(STAGES - 2); // __pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id); // sync_slice(slice_id);
iter_A.next_stage(); // iter_A.next_stage();
iter_Q.next_stage(); // iter_Q.next_stage();
iter_B.next_stage(); // iter_B.next_stage();
warp_iter_A.next_stage(); // warp_iter_A.next_stage();
warp_iter_B.next_stage(); // warp_iter_B.next_stage();
--gemm_iter; // --gemm_iter;
} // }
} // }
} }
template<typename T, int N> template<typename T, int N>
...@@ -235,35 +236,35 @@ struct Gemm { ...@@ -235,35 +236,35 @@ struct Gemm {
__device__ void sync_slice(int slice_id) __device__ void sync_slice(int slice_id)
{ {
if constexpr (SLICES == 1) { // if constexpr (SLICES == 1) {
__syncthreads(); // __syncthreads();
} // }
else { // else {
constexpr int SLICE_GROUP = (SLICES + 7) / 8; // constexpr int SLICE_GROUP = (SLICES + 7) / 8;
constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE; // constexpr uint32_t num_threads = kWarpCountMN * WARP_SIZE;
const uint32_t barrier_id = slice_id / SLICE_GROUP + 1; // const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
// asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads)); // // asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
} // }
} }
__device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id) __device__ void load_partial(float* tb_frag_C, const float* partial_C, int cta, int slice_id)
{ {
if (slice_id == 0) { // if (slice_id == 0) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) { // for (int i = 0; i < CTA_N; ++i) {
tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x]; // tb_frag_C[i] += partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x];
} // }
} // }
} }
__device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id) __device__ void store_partial(float* partial_C, const float* tb_frag_C, int cta, int slice_id)
{ {
if (slice_id == 0) { // if (slice_id == 0) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int i = 0; i < CTA_N; ++i) { // for (int i = 0; i < CTA_N; ++i) {
partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i]; // partial_C[cta * CTA_N * CTA_M + i * CTA_M + threadIdx.x] = tb_frag_C[i];
} // }
} // }
} }
template<int Index> template<int Index>
...@@ -280,80 +281,80 @@ struct Gemm { ...@@ -280,80 +281,80 @@ struct Gemm {
int slice_id) int slice_id)
{ {
if (slice_id != 0) { // if (slice_id != 0) {
return; // return;
} // }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c // // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) { // for (int i = 0; i < WARP_N / OP_N; ++i) {
const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4]; // const float2* frag_C = (float2*)&tb_frag_C[i * WARP_M / OP_M * 4];
const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4; // const int nn = cta_n + warp_id_n * WARP_N + i * OP_N + lane_id / 4;
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) { // for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int x = 0; x < 2; ++x) { // for (int x = 0; x < 2; ++x) {
const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2; // const int mm = cta_m + warp_id_m * WARP_M + j * OP_M + x * 8 + lane_id % 4 * 2;
// convert to half // // convert to half
half2 half_C = __float22half2_rn(frag_C[j * 2 + x]); // half2 half_C = __float22half2_rn(frag_C[j * 2 + x]);
// transpose 8x8 accum tile // // transpose 8x8 accum tile
uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id); // uint trans_C = transpose_m8n8_b16((uint&)half_C, lane_id);
// store to global memory // // store to global memory
OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n); // OutputOps::template apply<Index>(trans_C, mm, nn, C, m, n);
} // }
} // }
} // }
} }
__device__ void __device__ void
sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id) sum_slices(float* tb_frag_C, float* tb_smem_C, int warp_id_m, int warp_id_n, int lane_id, int slice_id)
{ {
int offset_m = warp_id_m * WARP_M / OP_M; // int offset_m = warp_id_m * WARP_M / OP_M;
int offset_n = warp_id_n * WARP_N / OP_N; // int offset_n = warp_id_n * WARP_N / OP_N;
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int z = 0; z < SLICES; ++z) { // for (int z = 0; z < SLICES; ++z) {
if (slice_id == z) { // if (slice_id == z) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) { // for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) { // for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) { // for (int x = 0; x < 4; ++x) {
int src = (i * WARP_M / OP_M + j) * 4 + x; // int src = (i * WARP_M / OP_M + j) * 4 + x;
int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x; // int dst = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
if (z > 0) { // if (z > 0) {
using namespace ops; // using namespace ops;
tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src]; // tb_frag_C[src] = tb_smem_C[dst * WARP_SIZE + lane_id] + tb_frag_C[src];
} // }
tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src]; // tb_smem_C[dst * WARP_SIZE + lane_id] = tb_frag_C[src];
} // }
} // }
} // }
} // }
__syncthreads(); // __syncthreads();
} // }
if (slice_id == 0) { // if (slice_id == 0) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int i = 0; i < WARP_N / OP_N; ++i) { // for (int i = 0; i < WARP_N / OP_N; ++i) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int j = 0; j < WARP_M / OP_M; ++j) { // for (int j = 0; j < WARP_M / OP_M; ++j) {
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int x = 0; x < 4; ++x) { // for (int x = 0; x < 4; ++x) {
int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x; // int src = ((i + offset_n) * CTA_M / OP_M + j + offset_m) * 4 + x;
int dst = (i * WARP_M / OP_M + j) * 4 + x; // int dst = (i * WARP_M / OP_M + j) * 4 + x;
tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id]; // tb_frag_C[dst] = tb_smem_C[src * WARP_SIZE + lane_id];
} // }
} // }
} // }
} // }
} }
Array<half, 8> warp_frag_A_[2][WARP_M / OP_M]; // Array<half, 8> warp_frag_A_[2][WARP_M / OP_M];
Array<half, 4> warp_frag_B_[2][WARP_N / OP_N]; // Array<half, 4> warp_frag_B_[2][WARP_N / OP_N];
__device__ void run_v2(half* __restrict__ C, __device__ void run_v2(half* __restrict__ C,
const uint* __restrict__ A, const uint* __restrict__ A,
...@@ -364,89 +365,89 @@ struct Gemm { ...@@ -364,89 +365,89 @@ struct Gemm {
int K, int K,
int output_op_idx) int output_op_idx)
{ {
static_assert(WARP_M % OP_N == 0); // static_assert(WARP_M % OP_N == 0);
float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4]; // float tb_frag_C[(WARP_N / OP_N) * (WARP_M / OP_M) * 4];
extern __shared__ uint8_t smem[]; // extern __shared__ uint8_t smem[];
const int warp_id = threadIdx.x / WARP_SIZE; // const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; // const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id_m = warp_id % kWarpCountM; // const int warp_id_m = warp_id % kWarpCountM;
const int warp_id_nk = warp_id / kWarpCountM; // const int warp_id_nk = warp_id / kWarpCountM;
const int warp_id_n = warp_id_nk % kWarpCountN; // const int warp_id_n = warp_id_nk % kWarpCountN;
const int warp_id_k = warp_id_nk / kWarpCountN; // const int warp_id_k = warp_id_nk / kWarpCountN;
const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m; // const int warp_id_mn = warp_id_n * kWarpCountM + warp_id_m;
const int slice_id = warp_id_k; // const int slice_id = warp_id_k;
const int cta_k = slice_id * SLICE_K; // sliced-k offset // const int cta_k = slice_id * SLICE_K; // sliced-k offset
const int cta_m = blockIdx.x * CTA_M; // const int cta_m = blockIdx.x * CTA_M;
const int cta_n = blockIdx.y * CTA_N; // const int cta_n = blockIdx.y * CTA_N;
// each slice has its own partition of smem // // each slice has its own partition of smem
uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id); // uint4* const tb_smem_A = (uint4*)(smem + IteratorA::kSmemByteSize * slice_id);
half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id); // half* const tb_smem_B = (half*)(smem + IteratorA::kSmemByteSize * SLICES + IteratorB::kSmemByteSize * slice_id);
// [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA // // [CTA_N / OP_N, CTA_M / OP_M, 4, WARP_SIZE], all mn fragments in CTA
float* const tb_smem_C = (float*)smem; // float* const tb_smem_C = (float*)smem;
__shared__ typename IteratorQ::Storage tb_smem_Q_storage; // __shared__ typename IteratorQ::Storage tb_smem_Q_storage;
auto tb_smem_Q = tb_smem_Q_storage.data[slice_id]; // auto tb_smem_Q = tb_smem_Q_storage.data[slice_id];
IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id}; // IteratorA iter_A{A, tb_smem_A, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id}; // IteratorQ iter_Q{Q, tb_smem_Q, M, K, cta_m, cta_k, warp_id_mn, lane_id};
IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id}; // IteratorB iter_B{B, tb_smem_B, K, N, cta_n, cta_k, warp_id_mn, lane_id};
const int offset_m = warp_id_m * WARP_M + lane_id; // const int offset_m = warp_id_m * WARP_M + lane_id;
WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k); // WarpIterA warp_iter_A(iter_A.smem_, iter_Q.smem_, warp_id, lane_id, offset_m, cta_k);
WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0); // WarpIterB warp_iter_B(iter_B.smem_int_ptr_, warp_id_n, lane_id, 0);
int gemm_iter = (K + CTA_K - 1) / CTA_K; // int gemm_iter = (K + CTA_K - 1) / CTA_K;
PRAGMA_UNROLL // PRAGMA_UNROLL
for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) { // for (int stage = 0; stage < STAGES - 1; ++stage, --gemm_iter) {
iter_A.prefetch_stage(gemm_iter > 0); // iter_A.prefetch_stage(gemm_iter > 0);
iter_Q.prefetch_stage(gemm_iter > 0); // iter_Q.prefetch_stage(gemm_iter > 0);
iter_B.prefetch_stage(gemm_iter > 0); // iter_B.prefetch_stage(gemm_iter > 0);
__pipeline_commit(); // __pipeline_commit();
} // }
clear(tb_frag_C); // clear(tb_frag_C);
__pipeline_wait_prior(STAGES - 2); // __pipeline_wait_prior(STAGES - 2);
sync_slice(slice_id); // sync_slice(slice_id);
warp_iter_A.load(warp_frag_A_[0], 0); // warp_iter_A.load(warp_frag_A_[0], 0);
warp_iter_B.load(warp_frag_B_[0], 0); // warp_iter_B.load(warp_frag_B_[0], 0);
PRAGMA_NO_UNROLL // PRAGMA_NO_UNROLL
for (; gemm_iter > -STAGES + 1;) { // for (; gemm_iter > -STAGES + 1;) {
warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter); // warp_mma(iter_A, iter_Q, iter_B, warp_iter_A, warp_iter_B, tb_frag_C, slice_id, gemm_iter);
} // }
__pipeline_commit(); // __pipeline_commit();
__pipeline_wait_prior(0); // __pipeline_wait_prior(0);
__syncthreads(); // __syncthreads();
if constexpr (SLICES > 1) { // if constexpr (SLICES > 1) {
sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id); // sum_slices(tb_frag_C, tb_smem_C, warp_id_m, warp_id_n, lane_id, slice_id);
} // }
switch (output_op_idx) { // switch (output_op_idx) {
case 0: // case 0:
store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id); // store_accum<0>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break; // break;
case 1: // case 1:
store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id); // store_accum<1>(tb_frag_C, tb_smem_C, C, M, N, cta_m, cta_n, warp_id_m, warp_id_n, lane_id, slice_id);
break; // break;
default: // default:
return; // return;
} // }
} }
}; };
......
...@@ -78,6 +78,7 @@ bool BlockManager::Malloc() ...@@ -78,6 +78,7 @@ bool BlockManager::Malloc()
return false; return false;
} }
//auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
auto ptr = (uint8_t*)allocator_->malloc(block_size_ * chunk_size); auto ptr = (uint8_t*)allocator_->malloc(block_size_ * chunk_size);
if (!ptr) { if (!ptr) {
return false; return false;
......
...@@ -19,13 +19,14 @@ add_library(Llama STATIC ...@@ -19,13 +19,14 @@ add_library(Llama STATIC
unified_attention_layer.cc unified_attention_layer.cc
llama_kernels.cu llama_kernels.cu
llama_decoder_kernels.cu llama_decoder_kernels.cu
llama_utils.cu) llama_utils.cu
./awq_sugon/gemm_w4_dequation.cu)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
#set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(Llama PUBLIC cudart target_link_libraries(Llama PUBLIC cudart
# gemm_s4_f16 gemm_s4_f16
cublasMMWrapper cublasMMWrapper
DynamicDecodeLayer DynamicDecodeLayer
activation_kernels activation_kernels
...@@ -41,7 +42,8 @@ target_link_libraries(Llama PUBLIC cudart ...@@ -41,7 +42,8 @@ target_link_libraries(Llama PUBLIC cudart
memory_utils memory_utils
nccl_utils nccl_utils
cuda_utils cuda_utils
logger) logger
gemm_multiB_int4)
# llama_fmha) # llama_fmha)
if (NOT MSVC) if (NOT MSVC)
......
...@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num, ...@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t inter_size, size_t inter_size,
WeightType weight_type, WeightType weight_type,
int group_size, int group_size,
int w4_weight_layout,
bool attn_bias, bool attn_bias,
size_t tensor_para_size, size_t tensor_para_size,
size_t tensor_para_rank): size_t tensor_para_rank):
...@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num, ...@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
self_attn_weights.qkv.type = weight_type; self_attn_weights.qkv.type = weight_type;
self_attn_weights.qkv.group_size = group_size; self_attn_weights.qkv.group_size = group_size;
self_attn_weights.qkv.w4_weight_layout = w4_weight_layout;
self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
self_attn_weights.output.output_dims = hidden_units_; self_attn_weights.output.output_dims = hidden_units_;
self_attn_weights.output.type = weight_type; self_attn_weights.output.type = weight_type;
self_attn_weights.output.group_size = group_size; self_attn_weights.output.group_size = group_size;
self_attn_weights.output.w4_weight_layout = w4_weight_layout;
ffn_weights.gating.input_dims = hidden_units_; ffn_weights.gating.input_dims = hidden_units_;
ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_; ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.gating.type = weight_type; ffn_weights.gating.type = weight_type;
ffn_weights.gating.group_size = group_size; ffn_weights.gating.group_size = group_size;
ffn_weights.gating.w4_weight_layout = w4_weight_layout;
ffn_weights.intermediate.input_dims = hidden_units_; ffn_weights.intermediate.input_dims = hidden_units_;
ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_; ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.intermediate.type = weight_type; ffn_weights.intermediate.type = weight_type;
ffn_weights.intermediate.group_size = group_size; ffn_weights.intermediate.group_size = group_size;
ffn_weights.intermediate.w4_weight_layout = w4_weight_layout;
ffn_weights.fused_gating_intermediate.input_dims = hidden_units_; ffn_weights.fused_gating_intermediate.input_dims = hidden_units_;
ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2; ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
ffn_weights.fused_gating_intermediate.type = weight_type; ffn_weights.fused_gating_intermediate.type = weight_type;
ffn_weights.fused_gating_intermediate.group_size = group_size; ffn_weights.fused_gating_intermediate.group_size = group_size;
ffn_weights.fused_gating_intermediate.w4_weight_layout = w4_weight_layout;
ffn_weights.output.input_dims = inter_size_ / tensor_para_size_; ffn_weights.output.input_dims = inter_size_ / tensor_para_size_;
ffn_weights.output.output_dims = hidden_units_; ffn_weights.output.output_dims = hidden_units_;
ffn_weights.output.type = weight_type; ffn_weights.output.type = weight_type;
ffn_weights.output.group_size = group_size; ffn_weights.output.group_size = group_size;
ffn_weights.output.w4_weight_layout = w4_weight_layout;
mallocWeights(); mallocWeights();
} }
...@@ -111,10 +118,28 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias) ...@@ -111,10 +118,28 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
else { // int8, int4 else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size; const int factor = sizeof(float) * 8 / bit_size;
FT_CHECK(weights.input_dims % factor == 0); FT_CHECK(weights.input_dims % factor == 0);
deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor); // //读环境变量
deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor); // int m_weightlayout_switch=1;
// interleaved scales/zeros // const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2); // if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((weights.input_dims%4096==0)&&(weights.w4_weight_layout==1||weights.w4_weight_layout==2))
{
size_t new_input_dims=weights.input_dims+2*weights.group_size;
deviceMalloc((int**)&weights.kernel, new_input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, new_input_dims* weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, new_input_dims / weights.group_size * weights.output_dims * 2);
}
else{
deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
}
} }
} }
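
mallocWeights now over-allocates the int4 buffers by two quantization groups when input_dims is a multiple of 4096 and the new layout is selected. A worked size check for one layer, with illustrative dimensions and the factor = 32/4 = 8 packing of int4 values into int32:

input_dims, output_dims, group_size = 4096, 4096, 128
factor = 32 // 4                                              # int4 values packed per int32

new_input_dims = input_dims + 2 * group_size                  # 4352
kernel_ints = new_input_dims * output_dims // factor          # 2228224 int32 elements
sz_halves = new_input_dims // group_size * output_dims * 2    # 278528 half elements
print(new_input_dims, kernel_ints, sz_halves)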
...@@ -146,16 +171,39 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string& ...@@ -146,16 +171,39 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
} }
else { // int8, int4 else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size; const int factor = sizeof(float) * 8 / bit_size;
output.insert(get_name("qweight"), // //读环境变量
Tensor{MEMORY_GPU, // int m_weightlayout_switch=1;
TYPE_INT32, // const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
{weights.input_dims * weights.output_dims * sizeof(int) / factor}, // if (env_weightlayout_str != nullptr) {
weights.kernel}); // m_weightlayout_switch = std::stoi(env_weightlayout_str);
output.insert(get_name("scales_zeros"), // }
Tensor{MEMORY_GPU, if((weights.input_dims%4096==0)&&(weights.w4_weight_layout==1||weights.w4_weight_layout==2))
getTensorType<T>(), {
{weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)}, size_t new_input_dims=weights.input_dims+weights.group_size;
weights.scales_and_zeros});
output.insert(get_name("qweight"),
Tensor{MEMORY_GPU,
TYPE_INT32,
{new_input_dims * weights.output_dims * sizeof(int) / factor},
weights.kernel});
output.insert(get_name("scales_zeros"),
Tensor{MEMORY_GPU,
getTensorType<T>(),
{new_input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros});
}
else{
output.insert(get_name("qweight"),
Tensor{MEMORY_GPU,
TYPE_INT32,
{weights.input_dims * weights.output_dims * sizeof(int) / factor},
weights.kernel});
output.insert(get_name("scales_zeros"),
Tensor{MEMORY_GPU,
getTensorType<T>(),
{weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros});
}
} }
} }
...@@ -259,12 +307,31 @@ void loadWeights(LlamaDenseWeight<T>& w, ...@@ -259,12 +307,31 @@ void loadWeights(LlamaDenseWeight<T>& w,
FT_CHECK(dim1 % factor == 0); FT_CHECK(dim1 % factor == 0);
std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)}; // //读环境变量
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {}); // int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((dim0%4096==0)&&(w.w4_weight_layout==1||w.w4_weight_layout==2))
{
size_t new_dim0=dim0+2*w.group_size;
std::vector<size_t> w_shape{new_dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
const size_t group_count = w.group_size > 0 ? new_dim0 / w.group_size : 1;
const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1; loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
}
else{
std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {}); loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
}
} }
} }
......
...@@ -35,6 +35,7 @@ public: ...@@ -35,6 +35,7 @@ public:
size_t inter_size, size_t inter_size,
WeightType weight_type, WeightType weight_type,
int group_size, int group_size,
int w4_weight_layout,
bool attn_bias, bool attn_bias,
size_t tensor_para_size, size_t tensor_para_size,
size_t tensor_para_rank); size_t tensor_para_rank);
......
...@@ -63,6 +63,7 @@ struct LlamaDenseWeight { ...@@ -63,6 +63,7 @@ struct LlamaDenseWeight {
T* bias; T* bias;
T* scales_and_zeros; T* scales_and_zeros;
int group_size; int group_size;
int w4_weight_layout;
}; };
template<typename T> template<typename T>
......
...@@ -29,7 +29,7 @@ namespace turbomind { ...@@ -29,7 +29,7 @@ namespace turbomind {
template<typename T> template<typename T>
void LlamaFfnLayer<T>::allocateBuffer(size_t token_num) void LlamaFfnLayer<T>::allocateBuffer(size_t token_num)
{ {
inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * token_num * inter_size_, false); inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, 2*sizeof(T) * token_num * inter_size_, false);
gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * token_num * inter_size_, false); gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * token_num * inter_size_, false);
is_allocate_buffer_ = true; is_allocate_buffer_ = true;
} }
...@@ -90,8 +90,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors, ...@@ -90,8 +90,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
if (weights->fused_gating_intermediate.kernel) { if (weights->fused_gating_intermediate.kernel) {
NvtxScope scope("fused_silu_ffn"); NvtxScope scope("fused_silu_ffn");
linear_.forward( // linear_.forward(
gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn); // gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
linear_.forward_ffn(
gating_buf_,inter_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
} }
else { else {
{ // w1(x) { // w1(x)
......