Unverified Commit 4a60b45d authored by Li Zhang, committed by GitHub

[Feature] Support Qwen-7B, dynamic NTK scaling and logN scaling in turbomind (#230)

* qwen support

* dynamic ntk & logn attn

* fix ntk & add chat template

* fix ntk scaling & stop words

* fix lint

* add tiktoken to requirements.txt

* fix tokenizer, set model format automatically

* update model.py

* update readme

* fix lint
parent 62b60db7
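The two long-context tricks named in the title can be stated in a few lines. The sketch below is a plain-Python illustration of what the new kernel helpers `rotary_embedding_get_base` and `logn_attn_get_scaling` (used further down in this diff) compute, following the formulas published with Qwen-7B; the helper names in the sketch, the default `trained_len`, and the exact rounding of the NTK factor are illustrative assumptions, not code from this commit.

```python
import math

# Hedged sketch: dynamic NTK scaling enlarges the RoPE base once the running
# context exceeds the trained length, and logN scaling damps the query so the
# attention logits keep a stable scale for long sequences. `trained_len` stands
# in for the `seq_length` value that deploy.py reads from config.json.

def dynamic_ntk_base(seq_len: int, trained_len: int = 2048,
                     rotary_dim: int = 128, base: float = 10000.0) -> float:
    """RoPE base enlarged for the current sequence length (NTK-aware)."""
    context = math.log2(max(seq_len, 1) / trained_len) + 1
    alpha = max(2 ** math.ceil(context) - 1, 1)  # >= 1, grows in powers of two
    return base * alpha ** (rotary_dim / (rotary_dim - 2))

def logn_scale(seq_len: int, trained_len: int = 2048) -> float:
    """Factor applied to q when the sequence outgrows the trained length."""
    return math.log(seq_len) / math.log(trained_len) if seq_len > trained_len else 1.0

print(dynamic_ntk_base(8192))  # base grows roughly 7x at 4x the trained length
print(logn_scale(8192))        # ~1.18
```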
@@ -13,6 +13,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
- \[2023/08\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models.
......
@@ -13,6 +13,7 @@ ______________________________________________________________________
## Updates 🎉
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation 🚀. See [here](./docs/zh_cn/w4a16.md) for deployment instructions
- \[2023/08\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models
......
@@ -16,7 +16,8 @@ class Tokenizer:
            self.pad_id = self.model.pad_id()
        else:
            from transformers import AutoTokenizer
            self.model = AutoTokenizer.from_pretrained(model_file,
                                                       trust_remote_code=True)
            self.vocab_size = self.model.vocab_size
            self.start_id = self.model.bos_token_id
            self.end_id = self.model.eos_token_id
......
@@ -161,6 +161,36 @@ If a question does not make any sense, or is not factually coherent, explain why
        return f'{self.b_inst} {prompt} {self.e_inst} '
@MODELS.register_module(name='qwen-7b')
class Qwen7BChat(BaseModel):
    """Chat template for Qwen-7B-Chat."""

    def __init__(self):
        super().__init__()
        self.session_len = 8192
        self.top_p = 0.5
        self.top_k = 40
        self.temperature = 1.0
        self.im_start = '<|im_start|>'
        self.im_end = '<|im_end|>'
        self.system = 'You are a helpful assistant.'

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return f'{self.im_start}system\n{self.system}{self.im_end}' \
                   f'\n{self.im_start}user\n{prompt}{self.im_end}' \
                   f'\n{self.im_start}assistant\n'

        return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
               f'\n{self.im_start}assistant\n'

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return [151645]  # <|im_end|>
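For reference, this is the ChatML-style string the template above renders for the first turn of a session. The snippet assumes the registry exposes a `get` method and that the class is importable from `lmdeploy.model`; both are illustrative assumptions rather than usage documented in this diff.

```python
# Hypothetical usage sketch of the chat template registered above.
from lmdeploy.model import MODELS

chat_template = MODELS.get('qwen-7b')()  # assumed registry API
print(chat_template.get_prompt('Hi there', sequence_start=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi there<|im_end|>
# <|im_start|>assistant
```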
def main(model_name: str = 'test'):
    assert model_name in MODELS.module_dict.keys(), \
        f"'{model_name}' is not supported. " \
......
@@ -16,7 +16,7 @@ from sentencepiece import SentencePieceProcessor
import lmdeploy
from lmdeploy.model import MODELS

supported_formats = ['llama', 'hf', 'awq', 'qwen']


def get_package_root_path():
@@ -84,7 +84,7 @@ def copy_triton_model_templates(_path: str):
    return None


def tokenizer_info_sp(model_path: str):
    """Return the vocabulary size, bos token id and eos token id.

    Args:
@@ -101,6 +101,13 @@ def tokenizer_info(model_path: str):
    return n_words, bos_id, eos_id
def tokenizer_info_qwen(model_dir: str):
    n_words = 151851
    bos_id = 0
    eos_id = 151643
    return n_words, bos_id, eos_id
def export(model_name: str,
           num_layer: int,
           norm_eps: float,
@@ -111,7 +118,11 @@ def export(model_name: str,
           tp: int,
           size_per_head: int = 128,
           group_size: int = 0,
           weight_type: str = 'fp16',
           max_position_embeddings: int = 0,
           use_dynamic_ntk: int = 0,
           use_logn_attn: int = 0,
           tokenizer_info=tokenizer_info_sp):
    """Export deploying information to a config file.

    Args:
@@ -191,7 +202,7 @@ def export(model_name: str,
                   head_num=head_num,
                   kv_head_num=kv_head_num,
                   size_per_head=size_per_head,
                   vocab_size=_vocab_size,
                   num_layer=num_layer,
                   rotary_embedding=size_per_head,
                   inter_size=inter_size,
@@ -210,7 +221,11 @@ def export(model_name: str,
                   cache_chunk_size=1,
                   use_context_fmha=1,
                   quant_policy=0,
                   tensor_para_size=tp,
                   # extra attention params
                   max_position_embeddings=max_position_embeddings,
                   use_dynamic_ntk=int(use_dynamic_ntk),
                   use_logn_attn=int(use_logn_attn)))

    config = configparser.ConfigParser()
    for section, key_values in cfg.items():
@@ -725,6 +740,134 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
                      group_size=group_size)
def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
                triton_models_path: str, tp: int):
    """Deploy a model with huggingface transformers' format.

    Args:
        model_name (str): the name of the to-be-deployed model
        model_path (str): the path of the directory where the model weight
            files are
        tokenizer_path (str): the path of the tokenizer model path
        triton_models_path (str): the path of the exported triton models
        tp (int): the number of tensor parallelism
        quant_path (str): path of the quantized model, which can be None
        group_size (int): a parameter used in AWQ to quantize fp16 weights
            to 4 bits
    """

    if osp.exists(model_path):
        shutil.copy(osp.join(model_path, 'qwen.tiktoken'),
                    osp.join(triton_models_path, 'tokenizer'))
        for _file in os.listdir(model_path):
            if _file.endswith('.json') or _file.endswith('.py'):
                json_path = osp.join(model_path, _file)
                shutil.copy(json_path,
                            osp.join(triton_models_path, 'tokenizer', _file))
        with get_package_root_path() as root_path:
            shutil.copy(osp.join(root_path, 'turbomind/tokenizer.py'),
                        osp.join(triton_models_path, 'tokenizer'))
    else:
        print(f'tokenizer model {tokenizer_path} does not exist')
        exit(-1)

    # read model arguments from params.json
    try:
        params_path = osp.join(model_path, 'config.json')
        with open(params_path) as f:
            config = json.load(f)
            num_layer = config['num_hidden_layers']
            norm_eps = config['layer_norm_epsilon']
            if 'num_key_value_heads' in config:
                kv_head_num = config['num_key_value_heads']
            else:
                kv_head_num = config['num_attention_heads']
            seq_length = config['seq_length']
            use_dynamic_ntk = config['use_dynamic_ntk']
            use_logn_attn = config['use_logn_attn']
    except Exception as e:
        print(f'get "num_hidden_layers" and "layer_norm_epsilon" from '
              f'{params_path} failed: {e}')
        return False

    # convert weights from hf to turbomind
    model_params = {}

    _files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
    _files = sorted(_files)
    print(_files)

    _params = {}
    for _file in _files:
        _tmp = torch.load(osp.join(model_path, _file), map_location='cpu')
        _params.update(_tmp)

    def get_tensor(name, trans=True):
        """return a transposed tensor according its name."""
        if trans:
            return _params[name].cuda().t()
        else:
            return _params[name].cuda()

    for i in range(num_layer):
        print(i)

        # qkv weights
        qkv_w = get_tensor(f'transformer.h.{i}.attn.c_attn.weight')
        q_w, k_w, v_w = torch.split(qkv_w, qkv_w.size(-1) // 3, dim=-1)
        q_w, k_w = permute(q_w), permute(k_w)
        qkv_w = merge_qkv(q_w, k_w, v_w, tp, dim=2)
        model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv_w

        # qkv bias
        qkv_b = get_tensor(f'transformer.h.{i}.attn.c_attn.bias')
        q_b, k_b, v_b = torch.split(qkv_b, qkv_b.size(-1) // 3)
        q_b, k_b = permute(q_b), permute(k_b)
        qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
        model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b

        # o weights
        o_w = get_tensor(f'transformer.h.{i}.attn.c_proj.weight')
        model_params[f'layers.{i}.attention.wo.weight'] = o_w
        model_params[f'layers.{i}.attention.wo.bias'] = torch.zeros_like(q_b)

        # ffn weights
        # ours: w2(silu(w1(x)) * w3(x))
        # qwen: c_proj(w1(x) * silu(w2(x)))
        w1 = get_tensor(f'transformer.h.{i}.mlp.w2.weight')
        w3 = get_tensor(f'transformer.h.{i}.mlp.w1.weight')
        w2 = get_tensor(f'transformer.h.{i}.mlp.c_proj.weight')
        model_params[f'layers.{i}.feed_forward.w1.weight'] = w1
        model_params[f'layers.{i}.feed_forward.w2.weight'] = w2
        model_params[f'layers.{i}.feed_forward.w3.weight'] = w3

        # norm weights
        attn_norm = get_tensor(f'transformer.h.{i}.ln_1.weight')
        ffn_norm = get_tensor(f'transformer.h.{i}.ln_2.weight')

        model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
        model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm

    other = [('tok_embeddings.weight', 'transformer.wte.weight'),
             ('norm.weight', 'transformer.ln_f.weight'),
             ('output.weight', 'lm_head.weight')]
    for ft, hf in other:
        model_params[ft] = get_tensor(hf, trans=False)

    return export(model_name,
                  num_layer,
                  norm_eps,
                  kv_head_num,
                  model_params,
                  model_path,
                  triton_models_path,
                  tp,
                  max_position_embeddings=seq_length,
                  use_dynamic_ntk=use_dynamic_ntk,
                  use_logn_attn=use_logn_attn,
                  tokenizer_info=tokenizer_info_qwen)
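The weight-name comment inside `deploy_qwen` ("ours: w2(silu(w1(x)) * w3(x)); qwen: c_proj(w1(x) * silu(w2(x)))") is easy to misread, so here is a small self-contained check of the mapping it implies: Qwen's `w2` becomes TurboMind's gate projection `w1`, Qwen's `w1` becomes the up projection `w3`, and `c_proj` becomes the down projection `w2`. The tensors below are random stand-ins, not repository code.

```python
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
x = torch.randn(1, hidden)

# Random stand-ins for Qwen's MLP weights, stored as [out, in] and applied as x @ W.T.
qwen_w1 = torch.randn(inter, hidden)      # qwen: plain branch
qwen_w2 = torch.randn(inter, hidden)      # qwen: SiLU branch
qwen_c_proj = torch.randn(hidden, inter)  # qwen: output projection

# qwen: c_proj(w1(x) * silu(w2(x)))
qwen_out = (x @ qwen_w1.T * F.silu(x @ qwen_w2.T)) @ qwen_c_proj.T

# TurboMind: w2(silu(w1(x)) * w3(x)) with w1 <- qwen.w2, w3 <- qwen.w1, w2 <- qwen.c_proj,
# the same renaming performed in deploy_qwen above.
tm_w1, tm_w3, tm_w2 = qwen_w2, qwen_w1, qwen_c_proj
tm_out = (F.silu(x @ tm_w1.T) * (x @ tm_w3.T)) @ tm_w2.T

assert torch.allclose(qwen_out, tm_out)
```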
def pack_model_repository(workspace_path: str):
    """package the model repository.
@@ -752,7 +895,7 @@ def pack_model_repository(workspace_path: str):
def main(model_name: str,
         model_path: str,
         model_format: str = None,
         tokenizer_path: str = None,
         dst_path: str = './workspace',
         tp: int = 1,
@@ -777,6 +920,9 @@ def main(model_name: str,
        f"'{model_name}' is not supported. " \
        f'The supported models are: {MODELS.module_dict.keys()}'

    if model_format is None:
        model_format = 'qwen' if model_name == 'qwen-7b' else 'hf'

    if model_format not in supported_formats:
        print(f'the model format "{model_format}" is not supported. '
              f'The supported format are: {supported_formats}')
@@ -803,6 +949,9 @@ def main(model_name: str,
    elif model_format == 'awq':
        res = deploy_awq(model_name, model_path, tokenizer_path,
                         triton_models_path, tp, quant_path, group_size)
    elif model_format == 'qwen':
        res = deploy_qwen(model_name, model_path, tokenizer_path,
                          triton_models_path, tp)

    # update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
    with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
......
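With the `qwen` format registered in `supported_formats` and selected automatically when `model_name` is `qwen-7b`, converting a checkpoint is a single call to `main`. The sketch below assumes the script is exposed as `lmdeploy.serve.turbomind.deploy` and that a Qwen-7B-Chat checkout sits in the local directory shown; both are assumptions for illustration.

```python
# Hypothetical invocation sketch; module path and checkpoint location are assumptions.
from lmdeploy.serve.turbomind.deploy import main

main(model_name='qwen-7b',         # picks the Qwen7BChat template and the 'qwen' format
     model_path='./Qwen-7B-Chat',  # HF-style dir with *.bin, config.json, qwen.tiktoken
     dst_path='./workspace',       # triton_models/ and config.ini are written here
     tp=1)                         # tensor parallelism
```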
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import Sequence, Union
@@ -99,6 +100,13 @@ class HuggingFaceTokenizer:
        if hasattr(self.model, 'backend_tokenizer'):
            self.model.backend_tokenizer.save(backend_tokenizer_file)

        if self.model.eos_token_id is None:
            generation_config_file = osp.join(model_dir,
                                              'generation_config.json')
            with open(generation_config_file, 'r') as f:
                cfg = json.load(f)
                self.model.eos_token_id = cfg['eos_token_id']

    @property
    def vocab_size(self):
        """vocabulary size."""
......
@@ -60,13 +60,6 @@ struct Multihead_attention_params_base {
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;
// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;
// scales
const float* query_weight_output_scale = nullptr;
const float* attention_qk_scale = nullptr;
@@ -108,10 +101,6 @@ struct Multihead_attention_params_base {
// The slope per head of linear position bias to attention score (H).
const T* linear_bias_slopes = nullptr;
const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
@@ -123,17 +112,15 @@ struct Multihead_attention_params_base {
template<typename T>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    bool* finished = nullptr;
    const int* length_per_sample = nullptr;
    T** k_cache_per_sample = nullptr;
    T** v_cache_per_sample = nullptr;
    size_t kv_cache_per_sample_offset = 0;
    int num_kv_heads = 0;
    int max_position_embeddings = 0;
    bool use_dynamic_ntk = false;
    bool use_logn_attn = false;
};

template<class T>
......
@@ -41,15 +41,12 @@
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;

    const int tlength = params.timestep;

    if (params.int8_mode == 4) {
        if (tlength < 32) {
......
@@ -18,8 +18,6 @@
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/turbomind/macro.h"
// #include "src/turbomind/utils/cuda_bf16_wrapper.h"
// #include "src/turbomind/utils/cuda_fp8_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh" #include "src/turbomind/utils/cuda_type_utils.cuh"
#include <assert.h> #include <assert.h>
#include <float.h> #include <float.h>
@@ -592,42 +590,6 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
    return b;
}
#ifdef ENABLE_FP8
// fp8_t
template<>
__inline__ __device__ float vec_conversion<float, __nv_fp8_e4m3>(const __nv_fp8_e4m3& a)
{
return float(a);
}
template<>
__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a)
{
return __nv_fp8_e4m3(a);
}
// fp8_2_t
template<>
__inline__ __device__ float2 vec_conversion<float2, fp8_2_t>(const fp8_2_t& a)
{
return float2(a);
}
template<>
__inline__ __device__ fp8_2_t vec_conversion<fp8_2_t, float2>(const float2& a)
{
return fp8_2_t(a);
}
// fp8_4_t
template<>
__inline__ __device__ float4 vec_conversion<float4, fp8_4_t>(const fp8_4_t& a)
{
return float4(a);
}
template<>
__inline__ __device__ fp8_4_t vec_conversion<fp8_4_t, float4>(const float4& a)
{
return fp8_4_t(a);
}
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_KEY, typename K_vec, int N>
@@ -868,19 +830,6 @@ inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_FP8
inline __device__ void convert_from_float(fp8_4_t& dst, float4 src)
{
dst = fp8_4_t(src);
}
inline __device__ void convert_from_float(fp8_2_t& dst, float2 src)
{
dst = fp8_2_t(src);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float2& dst, float2 src)
{
    dst = src;
@@ -1365,8 +1314,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The thread in the block.
const int tidx = threadIdx.x;
constexpr bool handle_kv = true;
// While doing the product Q*K^T for the different keys we track the max.
float qk_max = -FLT_MAX;
@@ -1422,28 +1369,36 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
Qk_vec_k k_bias;
zero(k_bias);

k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
    vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset])) :
    k_bias;

// Computes the Q/K values with bias.
q = add(q, q_bias);
k = add(k, k_bias);

float rotary_emb_base = 10000.f;
if (params.use_dynamic_ntk) {
    // +1 because of `length_per_sample == context_length - 1`
    rotary_emb_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
                                                params.max_position_embeddings,
                                                params.rotary_embedding_dim,
                                                rotary_emb_base);
}

// Padded len
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0) {
    apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_emb_base, params.timestep - padd_len);
}

if (params.use_logn_attn) {
    T log_n_scaling;
    // +1 because of `length_per_sample == context_length - 1`
    convert_from_float(log_n_scaling,
                       logn_attn_get_scaling(params.length_per_sample[bi] + 1, params.max_position_embeddings));
    q = mul<Qk_vec_k, T, Qk_vec_k>(log_n_scaling, q);
}

if (!is_masked) {
@@ -1462,47 +1417,23 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

if (group_leader) {
    // Trigger the stores to global memory.
    if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {

        int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
                     + co * QK_ELTS_IN_16B + ci;

        if (!QUANT_POLICY) {
            *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
                vec_conversion<Qk_vec_m, Qk_vec_k>(k);
        }
        else if (QUANT_POLICY == 4) {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
            Packed_Int8_t k_int8 = quant(k, k_scale, k_zp);

            int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
            *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
        }
    }
}
@@ -1584,33 +1515,21 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int8_t* k_cache_batch_int8 = nullptr;

if (!QUANT_POLICY) {
    k_cache_batch =
        params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
}
else if (QUANT_POLICY == 4) {
    // convert k_cache_per_sample to int8
    int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
    k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
}
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
// prefix prompt length if has
const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
// const int* beam_indices = HAS_BEAMS ? &params.cache_indir[bi_seq_len_offset] : nullptr;
for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
const int ti_circ = ti % params.memory_max_len;
@@ -1622,8 +1541,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ti_circ * Dh / QK_ELTS_IN_16B + ii;
// if( ti < params.timestep ) {
const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
if (ti < tlength) {
@@ -1632,9 +1550,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
else {
int beam_offset = 0;
if (HAS_BEAMS) {
beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
}
if (!QUANT_POLICY) {
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
@@ -1771,24 +1686,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int8_t* v_cache_batch_int8 = nullptr;

if (!QUANT_POLICY) {
    v_cache =
        params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
    v_cache_batch = v_cache;
}
else if (QUANT_POLICY == 4) {
    int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
    v_cache_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
    v_cache_batch_int8 = v_cache_int8;
}
@@ -1819,19 +1725,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// the compiler cannot optimize the codes automatically.
const int min_length = min(tlength, params.memory_max_len);
for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) {
// Fetch offset based on cache_indir when beam sampling
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v;
if (!QUANT_POLICY) {
    v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[ti * Dh]));
}
else if (QUANT_POLICY == 4) {
    Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[ti * Dh]);
    Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
    v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
...@@ -1867,18 +1768,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1867,18 +1768,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
} }
const int ti_circ = ti % params.memory_max_len; const int ti_circ = ti % params.memory_max_len;
// Fetch offset based on cache_indir when beam sampling
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v;
if (!QUANT_POLICY) {
    v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[ti_circ * Dh]));
}
else if (QUANT_POLICY == 4) {
    Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[ti_circ * Dh]);
    Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
    v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
@@ -1927,11 +1823,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}

// Store the V values to cache
if (group_leader) {
    // Store the values with bias back to global memory in the cache for V.
    if (!QUANT_POLICY) {
        *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
    }
......
@@ -19,25 +19,6 @@
namespace turbomind {
template<typename T>
void invokeAddQKVBiasIA3Transpose(T* q_buf,
T* k_buf,
T* v_buf,
T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream);
template<typename T, typename T_IN>
struct MaskedSoftmaxParam {
    // Common parameters.
@@ -69,27 +50,6 @@ void invokeTransposeQKV(T* dst,
                        const int int8_mode,
                        cudaStream_t stream);
template<typename T>
void invokeAddQKVBiasIA3RebuildPadding(T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
T* q_buf,
T* k_buf,
T* v_buf,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int valid_word_num,
const int* mask_offset,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream);
template<typename T>
void invokeTransposeAttentionOutRemovePadding(T* src,
                                              T* dst,
@@ -103,36 +63,26 @@ void invokeTransposeAttentionOutRemovePadding(T* src,
                                              const int int8_mode,
                                              cudaStream_t stream);
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
                                    T* k_buf,
                                    T* v_buf,
                                    T* QKV,
                                    const T* qkv_bias,
                                    const int* padding_offset,
                                    const int* history_length,
                                    const int* input_length,
                                    const int batch_size,
                                    const int seq_len,
                                    const int token_num,
                                    const int head_num,
                                    const int kv_head_num,
                                    const int size_per_head,
                                    const int rotary_embedding_dim,
                                    int max_position_embeddings,
                                    bool use_dynamic_ntk,
                                    bool use_logn_attn,
                                    cudaStream_t stream);
template<typename T>
void invokeTranspose4d(T* dst,
......
@@ -163,21 +163,21 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
invokeAddFusedQKVBiasTranspose(q_buf_2_,
                               k_buf_2_,
                               v_buf_2_,
                               qkv_buf_,
                               weights->qkv.bias,
                               padding_offset,  // padding_offset,
                               history_length,  // used for applying rotary embedding
                               input_length,
                               batch_size,
                               max_q_len,  // seq_len
                               num_token,  // batch_size * seq_len
                               local_head_num_,
                               local_kv_head_num_,
                               size_per_head_,
                               params_.rotray_embedding_dim,
                               params_.max_position_embeddings,
                               params_.use_dynamic_ntk,
                               params_.use_logn_attn,
                               stream_);
sync_check_cuda_error();
......
@@ -23,6 +23,7 @@
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/nccl_utils.h"
@@ -34,26 +35,24 @@ public:
    void freeBuffer();
    void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);

    LlamaContextAttentionLayer(size_t head_num,
                               size_t kv_head_num,
                               size_t size_per_head,
                               LlamaAttentionParams attn_params,
                               NcclParam tensor_para,
                               cudaStream_t stream,
                               cublasMMWrapper* cublas_wrapper,
                               IAllocator* allocator,
                               bool is_free_buffer_after_forward,
                               bool use_fmha,
                               int quant_policy):
        head_num_(head_num),
        size_per_head_(size_per_head),
        hidden_units_(head_num * size_per_head),
        local_head_num_(head_num / tensor_para.world_size_),
        local_kv_head_num_(kv_head_num / tensor_para.world_size_),
        head_n_rep_(head_num / kv_head_num),
        params_(attn_params),
        tensor_para_(tensor_para),
        stream_(stream),
        cublas_wrapper_(cublas_wrapper),
@@ -99,10 +98,9 @@ private:
    const size_t local_kv_head_num_;
    const size_t local_head_num_;
    const size_t head_n_rep_;
    const bool is_free_buffer_after_forward_;
    const LlamaAttentionParams params_;
    const bool use_fmha_;
    const int quant_policy_;
......
@@ -61,15 +61,17 @@ void LlamaContextDecoder<T>::freeBuffer()
}

template<typename T>
void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
                                        size_t kv_head_num,
                                        bool use_fmha,
                                        int quant_policy)
{
    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);

    context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
                                                                 kv_head_num,
                                                                 size_per_head_,
                                                                 attn_params,
                                                                 tensor_para_,
                                                                 stream_,
                                                                 cublas_wrapper_,
@@ -124,32 +126,31 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
}

template<typename T>
LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
                                            size_t kv_head_num,
                                            size_t size_per_head,
                                            size_t inter_size,
                                            size_t num_layer,
                                            const LlamaAttentionParams& attn_params,
                                            float rmsnorm_eps,
                                            NcclParam tensor_para,
                                            cudaStream_t stream,
                                            cublasMMWrapper* cublas_wrapper,
                                            IAllocator* allocator,
                                            bool is_free_buffer_after_forward,
                                            bool use_fmha,
                                            int quant_policy):
    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
    head_num_(head_num),
    size_per_head_(size_per_head),
    inter_size_(inter_size),
    hidden_units_(head_num * size_per_head),
    num_layer_(num_layer),
    rmsnorm_eps_(rmsnorm_eps),
    tensor_para_(tensor_para),
    data_type_(getTensorType<T>())
{
    initialize(attn_params, kv_head_num, use_fmha, quant_policy);
}

template<typename T>
......
@@ -20,14 +20,11 @@
#pragma once

#include "src/turbomind/layers/BaseLayer.h"
#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
@@ -43,13 +40,12 @@ protected:
    void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
    void freeBuffer() override;
    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_fmha, int quant_policy);

    size_t head_num_;
    size_t size_per_head_;
    size_t inter_size_;
    size_t num_layer_;
    size_t hidden_units_;
    float rmsnorm_eps_;
@@ -87,20 +83,20 @@ protected:
                     bool is_final);

public:
    LlamaContextDecoder(size_t head_num,
                        size_t kv_head_num,
                        size_t size_per_head,
                        size_t inter_size,
                        size_t num_layer,
                        const LlamaAttentionParams& attn_params,
                        float rmsnorm_eps,
                        NcclParam tensor_para,
                        cudaStream_t stream,
                        cublasMMWrapper* cublas_wrapper,
                        IAllocator* allocator,
                        bool is_free_buffer_after_forward,
                        bool use_fmha,
                        int quant_policy);

    ~LlamaContextDecoder() override;
......
@@ -23,37 +23,37 @@
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"

namespace turbomind {

template<typename T>
LlamaDecoder<T>::LlamaDecoder(size_t head_num,
                              size_t kv_head_num,
                              size_t size_per_head,
                              size_t inter_size,
                              size_t num_layer,
                              const LlamaAttentionParams& attn_params,
                              float rmsnorm_eps,
                              NcclParam tensor_para,
                              cudaStream_t stream,
                              cublasMMWrapper* cublas_wrapper,
                              IAllocator* allocator,
                              bool is_free_buffer_after_forward,
                              int quant_policy):
    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
    head_num_(head_num),
    size_per_head_(size_per_head),
    inter_size_(inter_size),
    num_layer_(num_layer),
    hidden_units_(head_num * size_per_head),
    rmsnorm_eps_(rmsnorm_eps),
    tensor_para_(tensor_para),
    data_type_(getTensorType<T>())
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    initialize(attn_params, kv_head_num, quant_policy);
}

template<typename T>
@@ -65,15 +65,14 @@ LlamaDecoder<T>::~LlamaDecoder()
}

template<typename T>
void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_,
                                                                  kv_head_num,
                                                                  size_per_head_,
                                                                  attn_params,
                                                                  tensor_para_,
                                                                  stream_,
                                                                  cublas_wrapper_,
......
@@ -20,10 +20,10 @@
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoder.h

#include "src/turbomind/layers/BaseLayer.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/custom_ar_comm.h"
#include "src/turbomind/utils/nccl_utils.h"
@@ -35,13 +35,12 @@ protected:
    void allocateBuffer() override;  // deprecated
    void allocateBuffer(size_t batch_size);
    void freeBuffer() override;
    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy);

    size_t head_num_;
    size_t size_per_head_;
    size_t inter_size_;
    size_t num_layer_;
    size_t hidden_units_;
    float rmsnorm_eps_;
@@ -69,19 +68,19 @@ protected:
    void forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer);

public:
    LlamaDecoder(size_t head_num,
                 size_t kv_head_num,
                 size_t size_per_head,
                 size_t inter_size,
                 size_t num_layer,
                 const LlamaAttentionParams& attn_params,
                 float rmsnorm_eps,
                 NcclParam tensor_para,
                 cudaStream_t stream,
                 cublasMMWrapper* cublas_wrapper,
                 IAllocator* allocator,
                 bool is_free_buffer_after_forward,
                 int quant_policy);

    ~LlamaDecoder() override;
......
@@ -61,6 +61,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
                                                       const int kv_head_num,
                                                       const int size_per_head,
                                                       const int rotary_embedding_dim,
                                                       const int max_position_embeddings,
                                                       const bool use_dynamic_ntk,
                                                       const bool use_logn_attn,
                                                       const int memory_max_len,
                                                       const int* prefix_prompt_lengths,
                                                       const int max_prefix_prompt_length,
@@ -110,13 +113,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
    FT_CHECK(k_cache_per_sample && v_cache_per_sample);

    params.k_cache_per_sample = reinterpret_cast<DataType**>(k_cache_per_sample);
    params.v_cache_per_sample = reinterpret_cast<DataType**>(v_cache_per_sample);
    params.kv_cache_per_sample_offset = kv_cache_per_sample_offset;
    params.batch_size = inference_batch_size;
    params.beam_width = beam_width;
    params.memory_max_len = memory_max_len;
@@ -128,8 +127,12 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
    params.num_heads = head_num;
    params.num_kv_heads = kv_head_num;
    params.hidden_size_per_head = size_per_head;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.max_position_embeddings = max_position_embeddings;
    params.use_dynamic_ntk = use_dynamic_ntk;
    params.use_logn_attn = use_logn_attn;

    // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
    params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling);
@@ -146,10 +149,6 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
    }
    params.max_input_length = max_input_len;

    params.int8_mode = int8_mode;

    if (int8_mode & QuantPolicy::kCacheKVInt8) {
@@ -185,8 +184,6 @@ void LlamaDecoderSelfAttentionLayer<T>::freeBuffer()
    if (is_allocate_buffer_) {
        allocator_->free((void**)(&qkv_buf_));
        allocator_->free((void**)(&context_buf_));
        is_allocate_buffer_ = false;
    }
}
@@ -263,7 +260,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
                                          local_head_num_,
                                          local_kv_head_num_,
                                          size_per_head_,
                                          params_.rotray_embedding_dim,
                                          params_.max_position_embeddings,
                                          params_.use_dynamic_ntk,
                                          params_.use_logn_attn,
                                          memory_len,
                                          nullptr,  // prefix_prompt_lengths
                                          0,        // max_prefix_prompt_length
......