Commit 8efb9210 authored by zhouxiang

Merge branch 'dtk24.04-v0.2.6_awq' into 'dtk24.04-v0.2.6'

Merge the int4 quantized inference support into the dtk2404-v0.2.6 branch

See merge request dcutoolkit/deeplearing/lmdeploy!2
parents 2326380c 175eaedb
...@@ -5,7 +5,7 @@ __pycache__/ ...@@ -5,7 +5,7 @@ __pycache__/
.vscode/ .vscode/
.idea/ .idea/
# C extensions # C extensions
*.so #*.so
# Distribution / packaging # Distribution / packaging
.Python .Python
......
The .so here is mainly used by the AWQ feature
\ No newline at end of file
...@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED ...@@ -366,7 +366,7 @@ add_library(transformer-shared SHARED
# $<TARGET_OBJECTS:flash_attention2> # $<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama> $<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend> $<TARGET_OBJECTS:LlamaTritonBackend>
# $<TARGET_OBJECTS:gemm_s4_f16> $<TARGET_OBJECTS:gemm_s4_f16>
$<TARGET_OBJECTS:TopKSamplingLayer> $<TARGET_OBJECTS:TopKSamplingLayer>
$<TARGET_OBJECTS:TopPSamplingLayer> $<TARGET_OBJECTS:TopPSamplingLayer>
$<TARGET_OBJECTS:TransformerTritonBackend> $<TARGET_OBJECTS:TransformerTritonBackend>
......
...@@ -61,7 +61,11 @@ class CLI(object): ...@@ -61,7 +61,11 @@ class CLI(object):
default=0, default=0,
help='A parameter used in awq to quantize fp16 weights ' help='A parameter used in awq to quantize fp16 weights '
'to 4 bits') 'to 4 bits')
parser.add_argument(
'--w4-weight-layout',
type=int,
default=2,
help='A parameter used in AWQ to control the weight layout')
parser.set_defaults(run=CLI.convert) parser.set_defaults(run=CLI.convert)
@staticmethod @staticmethod
......
...@@ -196,6 +196,7 @@ def main(model_name: str, ...@@ -196,6 +196,7 @@ def main(model_name: str,
tp: int = 1, tp: int = 1,
quant_path: str = None, quant_path: str = None,
group_size: int = 0, group_size: int = 0,
w4_weight_layout: int = 2,
**kwargs): **kwargs):
"""deploy llama family models via turbomind. """deploy llama family models via turbomind.
...@@ -215,6 +216,7 @@ def main(model_name: str, ...@@ -215,6 +216,7 @@ def main(model_name: str,
quant_path (str): Path of the quantized model, which can be None. quant_path (str): Path of the quantized model, which can be None.
group_size (int): a parameter used in AWQ to quantize fp16 weights group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits to 4 bits
w4_weight_layout (int): a parameter used in AWQ to control the weight layout
kwargs (dict): other params for convert kwargs (dict): other params for convert
""" """
...@@ -260,10 +262,13 @@ def main(model_name: str, ...@@ -260,10 +262,13 @@ def main(model_name: str,
cfg.tensor_para_size = tp cfg.tensor_para_size = tp
cfg.rotary_embedding = cfg.size_per_head cfg.rotary_embedding = cfg.size_per_head
cfg.group_size = group_size cfg.group_size = group_size
cfg.w4_weight_layout=w4_weight_layout
if inferred_model_format.find('awq') != -1: if inferred_model_format.find('awq') != -1:
cfg.weight_type = 'int4' cfg.weight_type = 'int4'
output_format = 'w4' output_format = 'w4'
assert group_size > 0, f'group_size: {group_size} should > 0' assert group_size > 0, f'group_size: {group_size} should > 0'
print("w4_weight_layout:",w4_weight_layout)
assert w4_weight_layout>=0 and w4_weight_layout<3,f'w4_weight_layout: {w4_weight_layout} should >= 0 and < 3'
else: else:
#output_format = update_output_format(model_name, inferred_model_format, #output_format = update_output_format(model_name, inferred_model_format,
# model_path, output_format) # model_path, output_format)
......
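For context, a hedged sketch of how the new option might be passed when converting an AWQ model; the module path and the model_path/model_format parameters are assumptions, since only the signature fragments above appear in this diff.

# Hypothetical invocation; anything not shown in this diff is an assumption.
from lmdeploy.turbomind.deploy.converter import main

main(model_name='llama2',
     model_path='/path/to/awq-quantized-model',  # hypothetical path
     model_format='awq',                         # selects weight_type='int4', output_format='w4'
     group_size=128,                             # must be > 0 for the awq format
     w4_weight_layout=2)                         # accepted values are 0, 1 or 2 (see the assert above)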
...@@ -5,6 +5,7 @@ import inspect ...@@ -5,6 +5,7 @@ import inspect
import io import io
import json import json
import os.path as osp import os.path as osp
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from configparser import ConfigParser from configparser import ConfigParser
...@@ -52,6 +53,7 @@ class TurbomindModelConfig: ...@@ -52,6 +53,7 @@ class TurbomindModelConfig:
rope_theta: float = 10000.0 rope_theta: float = 10000.0
size_per_head: int = 128 size_per_head: int = 128
group_size: int = 0 group_size: int = 0
w4_weight_layout : int = 2
max_batch_size: int = 64 max_batch_size: int = 64
max_context_token_num: int = 1 max_context_token_num: int = 1
step_length: int = 1 step_length: int = 1
...@@ -150,6 +152,12 @@ class BaseOutputModel(ABC): ...@@ -150,6 +152,12 @@ class BaseOutputModel(ABC):
self.to_file = to_file self.to_file = to_file
self.out_dir = out_dir self.out_dir = out_dir
self.tm_params = {} self.tm_params = {}
#self.weight_layout= 1
# read the environment variable
#env_weight_layout = os.environ.get('LMDEPLOY_WEIGHTLAYOUT_SWITCH', '1')
#self.weight_layout =int(env_weight_layout)
#print("self.weight_layout:",self.weight_layout)
@abstractmethod @abstractmethod
def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig: def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
...@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128): ...@@ -317,6 +325,10 @@ def permute(x: torch.Tensor, size_per_head: int = 128):
return x.view(n_heads, 2, dim // n_heads // 2, return x.view(n_heads, 2, dim // n_heads // 2,
1).transpose(1, 2).reshape(dim, 1) 1).transpose(1, 2).reshape(dim, 1)
def permute_trans(x: torch.Tensor):
    if x.shape[-1] > 1:
        dim = x.shape[-1]
        return x.view(-1, x.shape[-1]).transpose(0, 1).reshape(dim, -1)
    return x
def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int, def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
dim: int): dim: int):
......
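To illustrate, permute_trans is effectively a 2-D transpose of the interleaved scales/zeros tensor; a minimal check, not part of the converter:

import torch

# For a [k // group_size, n] scales/zeros tensor, permute_trans returns the
# [n, k // group_size] transpose.
x = torch.arange(6, dtype=torch.float16).view(3, 2)
y = x.view(-1, x.shape[-1]).transpose(0, 1).reshape(x.shape[-1], -1)
assert torch.equal(y, x.t())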
...@@ -8,7 +8,7 @@ import lmdeploy ...@@ -8,7 +8,7 @@ import lmdeploy
from ..source_model.base import BaseInputModel, BaseReader from ..source_model.base import BaseInputModel, BaseReader
from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig, from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig,
merge_qkv, permute) merge_qkv, permute,permute_trans)
# import _turbomind as _tm # import _turbomind as _tm
# TODO: find another way import _turbomind # TODO: find another way import _turbomind
...@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor, ...@@ -56,6 +56,18 @@ def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
qw.size(-1) * 8, qw.size(0), group_size) qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz return _qw, _sz
def convert_s4_(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
assert qw.is_contiguous()
assert qz.is_contiguous()
assert s.is_contiguous()
_qw = torch.zeros_like(qw)
_sz = torch.zeros_like(s, dtype=torch.int32) # half2
_ws = torch.zeros_like(s)
_tm.convert_s4_k_m8_(_qw, _sz, _ws, qw, s, qz,
qw.size(-1) * 8, qw.size(0), group_size)
return _qw, _sz
def tp_m_s4(x: torch.Tensor, tp: int): def tp_m_s4(x: torch.Tensor, tp: int):
return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3, return x.view(x.size(0) // 32, tp, -1, 128).permute(0, 2, 3,
...@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel): ...@@ -104,6 +116,7 @@ class TurbomindW4Model(BaseOutputModel):
"""Export transformer layer i.""" """Export transformer layer i."""
group_size = self.cfg.group_size group_size = self.cfg.group_size
tp = self.cfg.tensor_para_size tp = self.cfg.tensor_para_size
w4_weight_layout = self.cfg.w4_weight_layout
size_per_head = self.cfg.size_per_head size_per_head = self.cfg.size_per_head
# attn # attn
q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i)) q_qw, k_qw, v_qw, o_qw = get_cuda_tensor(bin.attn(i))
...@@ -121,11 +134,44 @@ class TurbomindW4Model(BaseOutputModel): ...@@ -121,11 +134,44 @@ class TurbomindW4Model(BaseOutputModel):
qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2) qkv_qz = merge_qkv(q_qz, k_qz, v_qz, tp, dim=2)
qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2) qkv_s = merge_qkv(q_s, k_s, v_s, tp, dim=2)
pad_group_count=2
if w4_weight_layout==1 or w4_weight_layout==2:
if qkv_qw.shape[0]%4096==0:
qkv_qw_padding=torch.zeros(group_size*pad_group_count,qkv_qw.shape[1],dtype=torch.int32).cuda()
qkv_qw =torch.cat((qkv_qw,qkv_qw_padding),dim=0).contiguous()
qkv_qz_padding =torch.zeros(pad_group_count,qkv_qz.shape[1],dtype=torch.int32).cuda()
qkv_qz =torch.cat((qkv_qz,qkv_qz_padding),dim=0).contiguous()
qkv_s_padding =torch.zeros(pad_group_count,qkv_s.shape[1],dtype=torch.float16).cuda()
qkv_s =torch.cat((qkv_s,qkv_s_padding),dim=0).contiguous()
qkv_qw, qkv_sz = convert_s4_(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp)
qkv_sz = permute_trans(qkv_sz)
else:
qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size) qkv_qw, qkv_sz = convert_s4(qkv_qw, qkv_qz, qkv_s, group_size)
qkv_qw = tp_m_s4(qkv_qw, tp) qkv_qw = tp_m_s4(qkv_qw, tp)
#print("Please set the weight layout\n")
self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1) self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1)
self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1) self.save_split(qkv_sz, f'layers.{i}.attention.w_qkv.scales_zeros', -1)
if w4_weight_layout==1 or w4_weight_layout==2:
if o_qw.shape[0]%4096==0:
o_qw_padding=torch.zeros(group_size*pad_group_count,o_qw.shape[1],dtype=torch.int32).cuda()
o_qw =torch.cat((o_qw,o_qw_padding),dim=0).contiguous()
o_qz_padding =torch.zeros(pad_group_count,o_qz.shape[1],dtype=torch.int32).cuda()
o_qz =torch.cat((o_qz,o_qz_padding),dim=0).contiguous()
o_s_padding =torch.zeros(pad_group_count,o_s.shape[1],dtype=torch.float16).cuda()
o_s =torch.cat((o_s,o_s_padding),dim=0).contiguous()
o_qw, o_sz = convert_s4_(o_qw, o_qz, o_s, group_size)
o_sz = permute_trans(o_sz)
else:
o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size) o_qw, o_sz = convert_s4(o_qw, o_qz, o_s, group_size)
self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0) self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0)
self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0) self.save_split(o_sz, f'layers.{i}.attention.wo.scales_zeros', 0)
...@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel): ...@@ -145,13 +191,45 @@ class TurbomindW4Model(BaseOutputModel):
w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz, w13_qw, w13_qz, w13_s = fuse_w1_w3_s4(w1_qw, w1_qz, w1_s, w3_qw, w3_qz,
w3_s) w3_s)
if w4_weight_layout==1 or w4_weight_layout==2:
if w13_qw.shape[0]%4096==0:
w13_qw_padding=torch.zeros(group_size*pad_group_count,w13_qw.shape[1],dtype=torch.int32).cuda()
w13_qw =torch.cat((w13_qw,w13_qw_padding),dim=0).contiguous()
w13_qz_padding =torch.zeros(pad_group_count,w13_qz.shape[1],dtype=torch.int32).cuda()
w13_qz =torch.cat((w13_qz,w13_qz_padding),dim=0).contiguous()
w13_s_padding =torch.zeros(pad_group_count,w13_s.shape[1],dtype=torch.float16).cuda()
w13_s =torch.cat((w13_s,w13_s_padding),dim=0).contiguous()
w13_qw, w13_sz = convert_s4_(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp)
w13_sz = permute_trans(w13_sz)
else:
w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size) w13_qw, w13_sz = convert_s4(w13_qw, w13_qz, w13_s, group_size)
w13_qw = tp_m_s4(w13_qw, tp) w13_qw = tp_m_s4(w13_qw, tp)
self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1) self.save_split(w13_qw, f'layers.{i}.feed_forward.w13.qweight', -1)
self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros', self.save_split(w13_sz, f'layers.{i}.feed_forward.w13.scales_zeros',
-1) -1)
if w4_weight_layout==1 or w4_weight_layout==2:
# padding
if w2_qw.shape[0]%4096==0:
w2_qw_padding=torch.zeros(group_size*pad_group_count,w2_qw.shape[1],dtype=torch.int32).cuda()
w2_qw =torch.cat((w2_qw,w2_qw_padding),dim=0).contiguous()
w2_qz_padding =torch.zeros(pad_group_count,w2_qz.shape[1],dtype=torch.int32).cuda()
w2_qz =torch.cat((w2_qz,w2_qz_padding),dim=0).contiguous()
w2_s_padding =torch.zeros(pad_group_count,w2_s.shape[1],dtype=torch.float16).cuda()
w2_s =torch.cat((w2_s,w2_s_padding),dim=0).contiguous()
w2_qw, w2_sz = convert_s4_(w2_qw, w2_qz, w2_s, group_size)
w2_sz = permute_trans(w2_sz)
else:
w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size) w2_qw, w2_sz = convert_s4(w2_qw, w2_qz, w2_s, group_size)
self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0) self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0)
self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0) self.save_split(w2_sz, f'layers.{i}.feed_forward.w2.scales_zeros', 0)
......
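The same pad-then-convert pattern is applied four times above (qkv, o, w13, w2). A minimal sketch of the padding step alone, with names of my own choosing; the caller still picks convert_s4_ or convert_s4 afterwards:

import torch

def pad_for_w4_layout(qw, qz, s, group_size, pad_group_count=2):
    """Append pad_group_count zero groups to the packed weights, zeros and scales,
    mirroring the branch taken above when the input dimension is a multiple of 4096."""
    qw = torch.cat((qw, qw.new_zeros(group_size * pad_group_count, qw.shape[1])), dim=0).contiguous()
    qz = torch.cat((qz, qz.new_zeros(pad_group_count, qz.shape[1])), dim=0).contiguous()
    s = torch.cat((s, s.new_zeros(pad_group_count, s.shape[1])), dim=0).contiguous()
    return qw, qz, s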
...@@ -147,6 +147,7 @@ class TurboMind: ...@@ -147,6 +147,7 @@ class TurboMind:
model_name: Optional[str] = None, model_name: Optional[str] = None,
model_format: Optional[str] = None, model_format: Optional[str] = None,
group_size: Optional[int] = None, group_size: Optional[int] = None,
w4_weight_layout: Optional[int] = None,
tp: Optional[int] = None, tp: Optional[int] = None,
chat_template_config: Optional[ChatTemplateConfig] = None, chat_template_config: Optional[ChatTemplateConfig] = None,
**kwargs): **kwargs):
...@@ -179,6 +180,7 @@ class TurboMind: ...@@ -179,6 +180,7 @@ class TurboMind:
engine_config = _update_engine_config(engine_config, engine_config = _update_engine_config(engine_config,
model_format=model_format, model_format=model_format,
group_size=group_size, group_size=group_size,
w4_weight_layout=w4_weight_layout,
tp=tp, tp=tp,
**kwargs) **kwargs)
...@@ -304,6 +306,7 @@ class TurboMind: ...@@ -304,6 +306,7 @@ class TurboMind:
output_format = 'w4' output_format = 'w4'
data_type = 'int4' data_type = 'int4'
cfg.group_size = 128 cfg.group_size = 128
cfg.w4_weight_layout=2
else: else:
# output_format = update_output_format(cfg.model_name, # output_format = update_output_format(cfg.model_name,
# inferred_model_format, # inferred_model_format,
...@@ -378,6 +381,7 @@ class TurboMind: ...@@ -378,6 +381,7 @@ class TurboMind:
self.config = cfg self.config = cfg
self.model_name = cfg.model_name self.model_name = cfg.model_name
self.data_type = cfg.weight_type self.data_type = cfg.weight_type
#print("from_workspace_cfg:",cfg)
# create model # create model
logger.warning(f'model_config:\n\n{cfg.toini()}') logger.warning(f'model_config:\n\n{cfg.toini()}')
......
...@@ -8,6 +8,7 @@ import subprocess ...@@ -8,6 +8,7 @@ import subprocess
from typing import Optional, Union from typing import Optional, Union
from pathlib import Path from pathlib import Path
import torch import torch
import shutil
pwd = os.path.dirname(__file__) pwd = os.path.dirname(__file__)
version_file = 'lmdeploy/version.py' version_file = 'lmdeploy/version.py'
...@@ -69,7 +70,6 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -69,7 +70,6 @@ def get_version_add(sha: Optional[str] = None) -> str:
file.writelines(lines) file.writelines(lines)
file.close() file.close()
def get_version(): def get_version():
get_version_add() get_version_add()
version_file = 'lmdeploy/version.py' version_file = 'lmdeploy/version.py'
...@@ -185,9 +185,24 @@ def parse_requirements(fname='requirements.txt', with_version=True): ...@@ -185,9 +185,24 @@ def parse_requirements(fname='requirements.txt', with_version=True):
packages += cuda_pkgs packages += cuda_pkgs
return packages return packages
def copy_ck_so():
lmdeploy_root = os.path.dirname(os.path.abspath(__file__))
so_path = os.path.join(os.path.join(lmdeploy_root, "3rdparty","composable_kernel"), "libgemm_multiB_int4.so")
# dtk version
target_path=os.path.join(lmdeploy_root, "lmdeploy","lib")
if os.path.exists(target_path):
shutil.copy(so_path, target_path)
elif os.getenv("ROCM_PATH"):
rocm_path = os.getenv('ROCM_PATH', "")
rocm_so_path = os.path.join(rocm_path, 'lib')
print("rocm_so_path:",rocm_so_path)
shutil.copy(so_path, rocm_so_path)
else:
shutil.copy(so_path, "/usr/local/lib")
if __name__ == '__main__': if __name__ == '__main__':
lmdeploy_package_data = ['lmdeploy/bin/llama_gemm'] lmdeploy_package_data = ['lmdeploy/bin/llama_gemm']
copy_ck_so()
setup( setup(
name='lmdeploy', name='lmdeploy',
version=get_version(), version=get_version(),
......
...@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) ...@@ -72,5 +72,5 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu)
#set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
#add_subdirectory(gemm_s_f16) add_subdirectory(gemm_s_f16)
add_subdirectory(decoder_multihead_attention) add_subdirectory(decoder_multihead_attention)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu) add_library(gemm_s4_f16 STATIC gemm_s4_f16.cu format.cu)
target_compile_options(gemm_s4_f16 PRIVATE target_compile_options(gemm_s4_f16 PRIVATE
......
...@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) ...@@ -72,19 +72,23 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[0]) // : "=r"(h[0])
// : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 h[0]=(i4s & BOTTOM_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[1]) // : "=r"(h[1])
// : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 h[1]=(i4s & TOP_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[2]) // : "=r"(h[2])
// : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 h[2]=(top_i4s & BOTTOM_MASK)|I4s_TO_F16s_MAGIC_NUM;
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
// : "=r"(h[3]) // : "=r"(h[3])
// : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
printf("=========common.h 86\n"); h[3]=(top_i4s & TOP_MASK)|I4s_TO_F16s_MAGIC_NUM;
// I use inline PTX below because I am not sure if the compiler will emit // I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose // float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability. // performance reliability over code readability.
...@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) ...@@ -102,14 +106,17 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers. // Finally, we construct the output numbers.
// Convert elt_01 // Convert elt_01
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); //asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// // Convert elt_23 h[0]=h[0]-FP16_TOP_MAGIC_NUM;
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); // Convert elt_23
// // Convert elt_45 //asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); h[1]=h[1]*ONE_SIXTEENTH+NEG_64;
// // Convert elt_67 // Convert elt_45
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); //asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
h[2]=h[2]-FP16_TOP_MAGIC_NUM;
// Convert elt_67
//asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
h[3]=h[3]*ONE_SIXTEENTH+NEG_64;
return result; return result;
} }
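The plain f16x2 arithmetic above replaces the PTX with the standard magic-number trick: ORing a 4-bit value v into the mantissa of 1024.0 (bit pattern 0x6400) yields 1024 + v in the low-nibble lanes and 1024 + 16*v in the high-nibble lanes, so the low lanes recover v by subtracting FP16_TOP_MAGIC_NUM while the high lanes use an fma with ONE_SIXTEENTH and NEG_64. A small check, assuming the usual 0x6400-based constants from this header:

import struct

def half_bits_to_float(bits: int) -> float:
    # Decode a raw 16-bit IEEE-754 half bit pattern into a Python float.
    return struct.unpack('<e', struct.pack('<H', bits))[0]

v = 0b1011                                     # an arbitrary 4-bit weight value (11)
lo = half_bits_to_float(0x6400 | v)            # low-nibble lane: 1024 + v
hi = half_bits_to_float(0x6400 | (v << 4))     # high-nibble lane: 1024 + 16 * v
assert lo - 1024.0 == v                        # the sub path (FP16_TOP_MAGIC_NUM)
assert hi / 16.0 - 64.0 == v                   # the fma path (ONE_SIXTEENTH, NEG_64)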
...@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source) ...@@ -131,31 +138,22 @@ __inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required. // dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8; const uint32_t top_i4s = i4s >> 8;
printf("=========common.h 133\n");
// if (0) { // 1024 & 64 // 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut)); h[0] =(i4s & BOT_MASK) |MAGIC_NUM_2;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut)); h[1] =(i4s & TOP_MASK) |MAGIC_NUM_1;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut)); h[2] =(top_i4s & BOT_MASK) |MAGIC_NUM_2;
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut)); h[3] =(top_i4s & TOP_MASK) |MAGIC_NUM_1;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0)); h[0] <<= 4;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1)); h[2] <<= 4;
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0)); // we don't need to subtract the magic nums because zeros will go through the same dequant function
// asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1)); // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
// else { // 64 only, trade 4 hfma2 with 2 shifts
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
// asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
// h[0] <<= 4;
// h[2] <<= 4;
// // we don't need to subtract the magic nums because zeros will go through the same dequant function
// // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
// }
return result; return result;
} }
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr) __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{ {
uint32_t smem_int_ptr; uint32_t smem_int_ptr;
...@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme ...@@ -220,12 +218,12 @@ __inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t sme
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q) __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{ {
uint s, z; //uint s, z;
(half2&)z = __halves2half2(q.x, q.x); //(half2&)z = __halves2half2(q.x, q.x);
(half2&)s = __halves2half2(q.y, q.y); //(half2&)s = __halves2half2(q.y, q.y);
auto& t = (const uint&)x; //auto& t = (const uint&)x;
uint u, v; uint v;
// if (TURBOMIND_S4_DEQUANT_USE_FMA) { // if (TURBOMIND_S4_DEQUANT_USE_FMA) {
// asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z)); // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
// } // }
...@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q) ...@@ -233,7 +231,7 @@ __inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
// asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z)); // asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
// asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s)); // asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
// } // }
printf("=========common.h 235\n");
return (half2&)v; return (half2&)v;
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "common.h" #include "common.h"
#include <iostream> #include <iostream>
#define BLOCKSIZE 256
namespace turbomind { namespace turbomind {
...@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre ...@@ -71,7 +72,17 @@ void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStre
// permutation for [k, m/8] layout // permutation for [k, m/8] layout
Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4}; Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
// |warp| lane | 2x2 | a0-7 | // |warp| lane | 2x2 | a0-7 |
permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape); //permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4<0, 1, 2, 3, 4, 5, 6, 7, 8, 9><<<512, 512, 0, st>>>(dst, src, shape);
}
void reformat_s4_k_m8_tarnsw4(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
// permutation for [k, m/8] layout
Array<int, 10> shape{1, k / 8, 2, 2, 2, 1, m / 8, 2, 2, 2};
// 0123456-->4,6,7,5,0,3,1,2
//permute_u4<4, 6, 7, 5, 0, 3, 1, 2><<<512, 512, 0, st>>>(dst, src, shape);
permute_u4<5, 6, 8, 9, 7, 0, 1, 4, 2, 3><<<512, 512, 0, st>>>(dst, src, shape);
} }
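Note that the order <0, 1, 2, 3, 4, 5, 6, 7, 8, 9> now used in reformat_s4_k_m8 is the identity permutation, so the quantized weights stay in their plain [k, m/8] order instead of the interleaved layout produced by the old <0, 5, 9, 8, 3, 1, 6, 4, 2, 7> order. A conceptual (unpacked) model of what permute_u4 does, ignoring the 4-bit packing:

import torch

def permute_u4_reference(src: torch.Tensor, shape, order):
    """Sketch only: view the element stream as an N-D array of `shape`, permute its
    axes by `order`, and flatten. With the identity order the output equals the input."""
    return src.view(*shape).permute(*order).reshape(-1)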
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count) __global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
...@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst, ...@@ -112,6 +123,22 @@ void convert_s4_k_m8(uint32_t* A_dst,
reformat_s4_k_m8(A_dst, A_src, m, k, st); reformat_s4_k_m8(A_dst, A_src, m, k, st);
} }
void convert_s4_k_m8_(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st)
{
dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
reformat_s4_k_m8_tarnsw4(A_dst, A_src, m, k, st);
}
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st) void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{ {
Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2}; Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
...@@ -140,5 +167,118 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s ...@@ -140,5 +167,118 @@ void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t s
{ {
dequantize_s4_kernel<<<512, 512>>>(dst, src, count); dequantize_s4_kernel<<<512, 512>>>(dst, src, count);
} }
__global__ void dequant_kernel(int num_kernels,half* weight ,const half2* zeros_and_scales,int k,int n,int group_size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id%n;
int i=id/n;
half x=zeros_and_scales[i/group_size*n+j].data[0];
half y= zeros_and_scales[i/group_size*n+j].data[1];
float tmp=(weight[id]-x)*y;
weight[id]=__float2half(tmp);
}
__global__ void dequant_kernel_colmajor(int num_kernels,half* weight ,const half2* zeros_and_scales,int k,int n,int group_size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id/group_size;
half x=zeros_and_scales[j].data[0];
half y= zeros_and_scales[j].data[1];
float tmp=(weight[id]-x)*y;
weight[id]=__float2half(tmp);
}
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size)
{
dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
int num_kernels=k*n;
dequant_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,zeros_and_scales,k,n,group_size);
}
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size)
{
dequantize_s4_offset_64<<<256, 256, 0, stream>>>((uint4*)output, weight, k * n / 8);
int num_kernels=k*n;
dequant_kernel_colmajor<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,zeros_and_scales,k,n,group_size);
}
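For reference, dequant_kernel applies the usual per-group affine dequantization w = (q - zero) * scale, where zeros_and_scales stores one (zero, scale) half2 per group and column; any shared bias introduced while unpacking the 4-bit values cancels in the subtraction. A row-major sketch in plain PyTorch, abstracting away the packing:

import torch

def dequant_reference(w_q, zeros, scales, group_size):
    """w_q: [k, n] unpacked values; zeros/scales: [k // group_size, n]; sketch only."""
    z = zeros.repeat_interleave(group_size, dim=0).float()
    s = scales.repeat_interleave(group_size, dim=0).float()
    return ((w_q.float() - z) * s).half()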
__global__ void FusedSiluActivation_kernel(int num_kernels,half* output ,const uint32_t* src,int m,int n)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
auto data = ((half2*)src)[id];
float x= __half2float(data.data[0]);
float y= __half2float(data.data[1]);
float silu=x / (1.f + __expf(-x))*y;
output[id]=__float2half(silu);
}
__global__ void assign_kernel(int num_kernels,half* output ,const half* src,int m,int n)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
output[id]=src[id];
}
void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type)
{
int num_kernels=m*n;
switch (type) {
case 0:
assign_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels,output,src,m,n);
break;
case 1:
FusedSiluActivation_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(int(num_kernels/2),output,(const uint32_t*)src,m,n);
break;
default:
return;
}
}
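addFusedSiluActivation with type 1 treats the source as interleaved (gate, up) half pairs and writes silu(gate) * up, halving the element count; type 0 is a plain copy. A minimal reference for the fused path:

import torch

def fused_silu_reference(src: torch.Tensor) -> torch.Tensor:
    """src holds interleaved (gate, up) halves; returns silu(gate) * up (sketch only)."""
    pairs = src.view(-1, 2).float()
    gate, up = pairs[:, 0], pairs[:, 1]
    return (gate * torch.sigmoid(gate) * up).half()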
template <typename T>
__global__ void input_padding_kernel(int num_kernels,T* output,const T* input,int m,int k,int group_size,int count)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int j=id%(k+count*group_size);
int i=id/(k+count*group_size);
if(j<k)
{
output[i*(k+count*group_size)+j]=input[i*(k)+j];
}
else
{
output[i*(k+count*group_size)+j]=0.f;
}
}
template <typename T>
void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount)
{
// input is [m, k]; output is [m, k + pad_groupcount * group_size]
int num_kernels=m*(k+pad_groupcount*group_size);
input_padding_kernel<<<(num_kernels+BLOCKSIZE-1)/BLOCKSIZE,BLOCKSIZE,0,stream>>>(num_kernels, output,input,m,k,group_size,pad_groupcount);
}
#define INSTANTIATEINPUTPADING(T) \
template void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount);
INSTANTIATEINPUTPADING(__half)
} // namespace turbomind } // namespace turbomind
...@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst, ...@@ -23,6 +23,17 @@ void convert_s4_k_m8(uint32_t* A_dst,
int group_size, int group_size,
cudaStream_t st = {}); cudaStream_t st = {});
void convert_s4_k_m8_(uint32_t* A_dst,
half2* Q_dst,
half* workspace,
const uint32_t* A_src,
const half* scales,
const uint32_t* qzeros,
int m,
int k,
int group_size,
cudaStream_t st = {});
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {}); void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st = {});
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {}); void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st = {});
......
...@@ -9,10 +9,32 @@ ...@@ -9,10 +9,32 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
typedef struct ihipStream_t* hipStream_t;
extern void run_weight_only_gemm(const void *A,
const void *B0,
const void *B1,
void *C,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideB_padded,
int StrideC,
int Group,
void* splitK_padA_workspace,
int splitK_padA_workspace_elementSize,
hipStream_t stream_id=0);
namespace turbomind { namespace turbomind {
extern bool g_dump_kernel_info_once; extern bool g_dump_kernel_info_once;
void dequant_w4_gemm(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
void addFusedSiluActivation(cudaStream_t stream,half* output, const half* src,int m,int n,int type);
void dequant_w4_gemm_colmajor(cudaStream_t stream, half* output,const uint32_t* weight,const half2* zeros_and_scales,int k, int n, int group_size);
template <typename T>
void input_padding(cudaStream_t stream, T* output,const T* input,int m,int k,int group_size,int pad_groupcount);
class GemmS4F16 { class GemmS4F16 {
public: public:
......
...@@ -78,6 +78,7 @@ bool BlockManager::Malloc() ...@@ -78,6 +78,7 @@ bool BlockManager::Malloc()
return false; return false;
} }
//auto ptr = (std::byte*)allocator_->malloc(block_size_ * chunk_size);
auto ptr = (uint8_t*)allocator_->malloc(block_size_ * chunk_size); auto ptr = (uint8_t*)allocator_->malloc(block_size_ * chunk_size);
if (!ptr) { if (!ptr) {
return false; return false;
......
...@@ -19,13 +19,15 @@ add_library(Llama STATIC ...@@ -19,13 +19,15 @@ add_library(Llama STATIC
unified_attention_layer.cc unified_attention_layer.cc
llama_kernels.cu llama_kernels.cu
llama_decoder_kernels.cu llama_decoder_kernels.cu
llama_utils.cu) llama_utils.cu
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fPIC")
#set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) #set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
#set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) #set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_directories(Llama PUBLIC ../../../../3rdparty/composable_kernel/)
target_link_libraries(Llama PUBLIC cudart target_link_libraries(Llama PUBLIC cudart
# gemm_s4_f16 gemm_s4_f16
cublasMMWrapper cublasMMWrapper
DynamicDecodeLayer DynamicDecodeLayer
activation_kernels activation_kernels
...@@ -41,7 +43,8 @@ target_link_libraries(Llama PUBLIC cudart ...@@ -41,7 +43,8 @@ target_link_libraries(Llama PUBLIC cudart
memory_utils memory_utils
nccl_utils nccl_utils
cuda_utils cuda_utils
logger) logger
gemm_multiB_int4)
# llama_fmha) # llama_fmha)
if (NOT MSVC) if (NOT MSVC)
......
...@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num, ...@@ -41,6 +41,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t inter_size, size_t inter_size,
WeightType weight_type, WeightType weight_type,
int group_size, int group_size,
int w4_weight_layout,
bool attn_bias, bool attn_bias,
size_t tensor_para_size, size_t tensor_para_size,
size_t tensor_para_rank): size_t tensor_para_rank):
...@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num, ...@@ -58,31 +59,37 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
self_attn_weights.qkv.type = weight_type; self_attn_weights.qkv.type = weight_type;
self_attn_weights.qkv.group_size = group_size; self_attn_weights.qkv.group_size = group_size;
self_attn_weights.qkv.w4_weight_layout = w4_weight_layout;
self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
self_attn_weights.output.output_dims = hidden_units_; self_attn_weights.output.output_dims = hidden_units_;
self_attn_weights.output.type = weight_type; self_attn_weights.output.type = weight_type;
self_attn_weights.output.group_size = group_size; self_attn_weights.output.group_size = group_size;
self_attn_weights.output.w4_weight_layout = w4_weight_layout;
ffn_weights.gating.input_dims = hidden_units_; ffn_weights.gating.input_dims = hidden_units_;
ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_; ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.gating.type = weight_type; ffn_weights.gating.type = weight_type;
ffn_weights.gating.group_size = group_size; ffn_weights.gating.group_size = group_size;
ffn_weights.gating.w4_weight_layout = w4_weight_layout;
ffn_weights.intermediate.input_dims = hidden_units_; ffn_weights.intermediate.input_dims = hidden_units_;
ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_; ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_;
ffn_weights.intermediate.type = weight_type; ffn_weights.intermediate.type = weight_type;
ffn_weights.intermediate.group_size = group_size; ffn_weights.intermediate.group_size = group_size;
ffn_weights.intermediate.w4_weight_layout = w4_weight_layout;
ffn_weights.fused_gating_intermediate.input_dims = hidden_units_; ffn_weights.fused_gating_intermediate.input_dims = hidden_units_;
ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2; ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
ffn_weights.fused_gating_intermediate.type = weight_type; ffn_weights.fused_gating_intermediate.type = weight_type;
ffn_weights.fused_gating_intermediate.group_size = group_size; ffn_weights.fused_gating_intermediate.group_size = group_size;
ffn_weights.fused_gating_intermediate.w4_weight_layout = w4_weight_layout;
ffn_weights.output.input_dims = inter_size_ / tensor_para_size_; ffn_weights.output.input_dims = inter_size_ / tensor_para_size_;
ffn_weights.output.output_dims = hidden_units_; ffn_weights.output.output_dims = hidden_units_;
ffn_weights.output.type = weight_type; ffn_weights.output.type = weight_type;
ffn_weights.output.group_size = group_size; ffn_weights.output.group_size = group_size;
ffn_weights.output.w4_weight_layout = w4_weight_layout;
mallocWeights(); mallocWeights();
} }
...@@ -111,11 +118,29 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias) ...@@ -111,11 +118,29 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
else { // int8, int4 else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size; const int factor = sizeof(float) * 8 / bit_size;
FT_CHECK(weights.input_dims % factor == 0); FT_CHECK(weights.input_dims % factor == 0);
// // read the environment variable
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((weights.input_dims%4096==0)&&(weights.w4_weight_layout==1||weights.w4_weight_layout==2))
{
size_t new_input_dims=weights.input_dims+2*weights.group_size;
deviceMalloc((int**)&weights.kernel, new_input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, new_input_dims* weights.output_dims / factor);
// interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, new_input_dims / weights.group_size * weights.output_dims * 2);
}
else{
deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor); deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor);
deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor); deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor);
// interleaved scales/zeros // interleaved scales/zeros
deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2); deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2);
} }
}
} }
template<typename FirstArg, typename... Args> template<typename FirstArg, typename... Args>
...@@ -146,6 +171,28 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string& ...@@ -146,6 +171,28 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
} }
else { // int8, int4 else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size; const int factor = sizeof(float) * 8 / bit_size;
// // read the environment variable
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((weights.input_dims%4096==0)&&(weights.w4_weight_layout==1||weights.w4_weight_layout==2))
{
size_t new_input_dims=weights.input_dims+weights.group_size;
output.insert(get_name("qweight"),
Tensor{MEMORY_GPU,
TYPE_INT32,
{new_input_dims * weights.output_dims * sizeof(int) / factor},
weights.kernel});
output.insert(get_name("scales_zeros"),
Tensor{MEMORY_GPU,
getTensorType<T>(),
{new_input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros});
}
else{
output.insert(get_name("qweight"), output.insert(get_name("qweight"),
Tensor{MEMORY_GPU, Tensor{MEMORY_GPU,
TYPE_INT32, TYPE_INT32,
...@@ -157,6 +204,7 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string& ...@@ -157,6 +204,7 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
{weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)}, {weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)},
weights.scales_and_zeros}); weights.scales_and_zeros});
} }
}
} }
template<typename T> template<typename T>
...@@ -259,6 +307,24 @@ void loadWeights(LlamaDenseWeight<T>& w, ...@@ -259,6 +307,24 @@ void loadWeights(LlamaDenseWeight<T>& w,
FT_CHECK(dim1 % factor == 0); FT_CHECK(dim1 % factor == 0);
// // read the environment variable
// int m_weightlayout_switch=1;
// const char* env_weightlayout_str = std::getenv("LMDEPLOY_WEIGHTLAYOUT_SWITCH");
// if (env_weightlayout_str != nullptr) {
// m_weightlayout_switch = std::stoi(env_weightlayout_str);
// }
if((dim0%4096==0)&&(w.w4_weight_layout==1||w.w4_weight_layout==2))
{
size_t new_dim0=dim0+2*w.group_size;
std::vector<size_t> w_shape{new_dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
const size_t group_count = w.group_size > 0 ? new_dim0 / w.group_size : 1;
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
}
else{
std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)}; std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {}); loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
...@@ -266,6 +332,7 @@ void loadWeights(LlamaDenseWeight<T>& w, ...@@ -266,6 +332,7 @@ void loadWeights(LlamaDenseWeight<T>& w,
loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {}); loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {});
} }
}
} }
template<typename T> template<typename T>
......
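The int4 branch above grows the allocation by two extra groups whenever input_dims is a multiple of 4096 and w4_weight_layout is 1 or 2. A small sketch of the resulting element counts (my own helper, assuming the int4 packing factor of 8 values per int32):

def padded_w4_sizes(input_dims, output_dims, group_size, w4_weight_layout, pad_group_count=2):
    """Mirror of the size computation in mallocWeights for the int4 path (sketch only)."""
    factor = 32 // 4  # 8 packed int4 values per int32 word
    dims = input_dims
    if input_dims % 4096 == 0 and w4_weight_layout in (1, 2):
        dims += pad_group_count * group_size
    qweight_int32 = dims * output_dims // factor              # kernel buffer, int32 elements
    scales_zeros_half = dims // group_size * output_dims * 2  # interleaved (scale, zero) halves
    return qweight_int32, scales_zeros_half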