Unverified commit 4a60b45d, authored by Li Zhang, committed by GitHub

[Feature] Support Qwen-7B, dynamic NTK scaling and logN scaling in turbomind (#230)

* qwen support

* dynamic ntk & logn attn

* fix ntk & add chat template

* fix ntk scaling & stop words

* fix lint

* add tiktoken to requirements.txt

* fix tokenizer, set model format automatically

* update model.py

* update readme

* fix lint
parent 62b60db7
...@@ -13,6 +13,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
- \[2023/08\] LMDeploy has launched on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models.
...
...@@ -13,6 +13,7 @@ ______________________________________________________________________
## Updates 🎉
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. See the deployment guide [here](./docs/zh_cn/w4a16.md)
- \[2023/08\] LMDeploy is now on the [HuggingFace Hub](https://huggingface.co/lmdeploy), providing ready-to-use 4-bit models
...
...@@ -16,7 +16,8 @@ class Tokenizer:
self.pad_id = self.model.pad_id()
else:
from transformers import AutoTokenizer
self.model = AutoTokenizer.from_pretrained(model_file,
trust_remote_code=True)
self.vocab_size = self.model.vocab_size
self.start_id = self.model.bos_token_id
self.end_id = self.model.eos_token_id
...
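The one-line change above matters for Qwen-7B because its tokenizer ships as custom (tiktoken-based) Python code inside the model directory, which `AutoTokenizer` only executes when `trust_remote_code=True`. A minimal, hedged sketch of loading it the same way; the local path is the tokenizer directory that the deploy script fills below and is only an example:

```python
# Hedged example: load the Qwen tokenizer the same way the patched Tokenizer class does.
# './workspace/triton_models/tokenizer' is where deploy.py copies qwen.tiktoken plus the
# tokenizer's *.json/*.py files; adjust the path to your own setup.
from transformers import AutoTokenizer

model_dir = './workspace/triton_models/tokenizer'
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

ids = tokenizer.encode('你好, TurboMind')
print(ids)
print(tokenizer.decode(ids))
```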
...@@ -161,6 +161,36 @@ If a question does not make any sense, or is not factually coherent, explain why
return f'{self.b_inst} {prompt} {self.e_inst} '
@MODELS.register_module(name='qwen-7b')
class Qwen7BChat(BaseModel):
"""Chat template for Qwen-7B-Chat."""
def __init__(self):
super().__init__()
self.session_len = 8192
self.top_p = 0.5
self.top_k = 40
self.temperature = 1.0
self.im_start = '<|im_start|>'
self.im_end = '<|im_end|>'
self.system = 'You are a helpful assistant.'
def get_prompt(self, prompt, sequence_start=True):
if sequence_start:
return f'{self.im_start}system\n{self.system}{self.im_end}' \
f'\n{self.im_start}user\n{prompt}{self.im_end}' \
f'\n{self.im_start}assistant\n'
return f'\n{self.im_start}user\n{prompt}{self.im_end}' \
f'\n{self.im_start}assistant\n'
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [151645] # <|im_end|>
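For reference, a short usage sketch (not part of the patch) of the template registered above; it relies only on the `MODELS` registry and the `get_prompt`/`stop_words` members defined in this class:

```python
# Build ChatML-style prompts with the newly registered 'qwen-7b' template.
from lmdeploy.model import MODELS

model = MODELS.get('qwen-7b')()
first = model.get_prompt('Hi, please introduce yourself', sequence_start=True)
nxt = model.get_prompt('What can you do?', sequence_start=False)
print(first)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi, please introduce yourself<|im_end|>
# <|im_start|>assistant
print(model.stop_words)  # [151645], the token id of <|im_end|>
```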
def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
...
...@@ -16,7 +16,7 @@ from sentencepiece import SentencePieceProcessor
import lmdeploy
from lmdeploy.model import MODELS
supported_formats = ['llama', 'hf', 'awq', 'qwen']
def get_package_root_path():
...@@ -84,7 +84,7 @@ def copy_triton_model_templates(_path: str):
return None
def tokenizer_info_sp(model_path: str):
"""Return the vocabulary size, bos token id and eos token id.
Args:
...@@ -101,6 +101,13 @@ def tokenizer_info(model_path: str):
return n_words, bos_id, eos_id
def tokenizer_info_qwen(model_dir: str):
n_words = 151851
bos_id = 0
eos_id = 151643
return n_words, bos_id, eos_id
def export(model_name: str,
num_layer: int,
norm_eps: float,
...@@ -111,7 +118,11 @@ def export(model_name: str,
tp: int,
size_per_head: int = 128,
group_size: int = 0,
weight_type: str = 'fp16',
max_position_embeddings: int = 0,
use_dynamic_ntk: int = 0,
use_logn_attn: int = 0,
tokenizer_info=tokenizer_info_sp):
"""Export deploying information to a config file.
Args:
...@@ -191,7 +202,7 @@ def export(model_name: str,
head_num=head_num,
kv_head_num=kv_head_num,
size_per_head=size_per_head,
vocab_size=_vocab_size,
num_layer=num_layer,
rotary_embedding=size_per_head,
inter_size=inter_size,
...@@ -210,7 +221,11 @@ def export(model_name: str,
cache_chunk_size=1,
use_context_fmha=1,
quant_policy=0,
tensor_para_size=tp,
# extra attention params
max_position_embeddings=max_position_embeddings,
use_dynamic_ntk=int(use_dynamic_ntk),
use_logn_attn=int(use_logn_attn)))
config = configparser.ConfigParser()
for section, key_values in cfg.items():
...@@ -725,6 +740,134 @@ def deploy_awq(model_name: str, model_path: str, tokenizer_path: str,
group_size=group_size)
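The three new keyword arguments of `export()` are written into the TurboMind config next to the existing attention settings. A hedged sketch of reading them back after deployment; the `config.ini` location under `triton_models/weights` and the `llama` section name follow the deploy script's defaults but are not shown in this excerpt, so treat them as assumptions:

```python
# Read back the extra attention params written by export(); paths/section are assumptions.
import configparser

config = configparser.ConfigParser()
config.read('./workspace/triton_models/weights/config.ini')
section = config['llama']
print(section.get('max_position_embeddings'),
      section.get('use_dynamic_ntk'),
      section.get('use_logn_attn'))
```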
def deploy_qwen(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int):
"""Deploy a model with huggingface transformers' format.
Args:
model_name (str): the name of the to-be-deployed model
model_path (str): the path of the directory where the model weight
files are
tokenizer_path (str): the path of the tokenizer model path
triton_models_path (str): the path of the exported triton models
tp (int): the number of tensor parallelism
quant_path (str): path of the quantized model, which can be None
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
"""
if osp.exists(model_path):
shutil.copy(osp.join(model_path, 'qwen.tiktoken'),
osp.join(triton_models_path, 'tokenizer'))
for _file in os.listdir(model_path):
if _file.endswith('.json') or _file.endswith('.py'):
json_path = osp.join(model_path, _file)
shutil.copy(json_path,
osp.join(triton_models_path, 'tokenizer', _file))
with get_package_root_path() as root_path:
shutil.copy(osp.join(root_path, 'turbomind/tokenizer.py'),
osp.join(triton_models_path, 'tokenizer'))
else:
print(f'tokenizer model {tokenizer_path} does not exist')
exit(-1)
# read model arguments from config.json
try:
params_path = osp.join(model_path, 'config.json')
with open(params_path) as f:
config = json.load(f)
num_layer = config['num_hidden_layers']
norm_eps = config['layer_norm_epsilon']
if 'num_key_value_heads' in config:
kv_head_num = config['num_key_value_heads']
else:
kv_head_num = config['num_attention_heads']
seq_length = config['seq_length']
use_dynamic_ntk = config['use_dynamic_ntk']
use_logn_attn = config['use_logn_attn']
except Exception as e:
print(f'get "num_hidden_layers" and "layer_norm_epsilon" from '
f'{params_path} failed: {e}')
return False
# convert weights from hf to turbomind
model_params = {}
_files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
_files = sorted(_files)
print(_files)
_params = {}
for _file in _files:
_tmp = torch.load(osp.join(model_path, _file), map_location='cpu')
_params.update(_tmp)
def get_tensor(name, trans=True):
"""return a transposed tensor according its name."""
if trans:
return _params[name].cuda().t()
else:
return _params[name].cuda()
for i in range(num_layer):
print(i)
# qkv weights
qkv_w = get_tensor(f'transformer.h.{i}.attn.c_attn.weight')
q_w, k_w, v_w = torch.split(qkv_w, qkv_w.size(-1) // 3, dim=-1)
q_w, k_w = permute(q_w), permute(k_w)
qkv_w = merge_qkv(q_w, k_w, v_w, tp, dim=2)
model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv_w
# qkv bias
qkv_b = get_tensor(f'transformer.h.{i}.attn.c_attn.bias')
q_b, k_b, v_b = torch.split(qkv_b, qkv_b.size(-1) // 3)
q_b, k_b = permute(q_b), permute(k_b)
qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1)
model_params[f'layers.{i}.attention.w_qkv.bias'] = qkv_b
# o weights
o_w = get_tensor(f'transformer.h.{i}.attn.c_proj.weight')
model_params[f'layers.{i}.attention.wo.weight'] = o_w
model_params[f'layers.{i}.attention.wo.bias'] = torch.zeros_like(q_b)
# ffn weights
# ours: w2(silu(w1(x)) * w3(x))
# qwen: c_proj(w1(x) * silu(w2(x)))
w1 = get_tensor(f'transformer.h.{i}.mlp.w2.weight')
w3 = get_tensor(f'transformer.h.{i}.mlp.w1.weight')
w2 = get_tensor(f'transformer.h.{i}.mlp.c_proj.weight')
model_params[f'layers.{i}.feed_forward.w1.weight'] = w1
model_params[f'layers.{i}.feed_forward.w2.weight'] = w2
model_params[f'layers.{i}.feed_forward.w3.weight'] = w3
# norm weights
attn_norm = get_tensor(f'transformer.h.{i}.ln_1.weight')
ffn_norm = get_tensor(f'transformer.h.{i}.ln_2.weight')
model_params[f'layers.{i}.attention_norm.weight'] = attn_norm
model_params[f'layers.{i}.ffn_norm.weight'] = ffn_norm
other = [('tok_embeddings.weight', 'transformer.wte.weight'),
('norm.weight', 'transformer.ln_f.weight'),
('output.weight', 'lm_head.weight')]
for ft, hf in other:
model_params[ft] = get_tensor(hf, trans=False)
return export(model_name,
num_layer,
norm_eps,
kv_head_num,
model_params,
model_path,
triton_models_path,
tp,
max_position_embeddings=seq_length,
use_dynamic_ntk=use_dynamic_ntk,
use_logn_attn=use_logn_attn,
tokenizer_info=tokenizer_info_qwen)
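The weight mapping in `deploy_qwen` (see the `ours`/`qwen` comment above) renames Qwen's gate/up/down projections into TurboMind's `w1`/`w3`/`w2` slots. A hedged sanity check, not part of the PR, that the two SwiGLU formulations agree under that mapping:

```python
# Sanity check for the FFN mapping noted above:
#   ours: w2(silu(w1(x)) * w3(x))    qwen: c_proj(w1(x) * silu(w2(x)))
# With w1 <- qwen.w2, w3 <- qwen.w1, w2 <- qwen.c_proj the two must match.
import torch
import torch.nn.functional as F

hidden, inter = 8, 16
x = torch.randn(1, hidden)
qwen_w1 = torch.randn(inter, hidden)      # transformer.h.i.mlp.w1
qwen_w2 = torch.randn(inter, hidden)      # transformer.h.i.mlp.w2
qwen_c_proj = torch.randn(hidden, inter)  # transformer.h.i.mlp.c_proj

qwen_out = F.linear(F.linear(x, qwen_w1) * F.silu(F.linear(x, qwen_w2)), qwen_c_proj)

w1, w3, w2 = qwen_w2, qwen_w1, qwen_c_proj             # the mapping used in deploy_qwen
ours_out = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

assert torch.allclose(qwen_out, ours_out)
```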
def pack_model_repository(workspace_path: str):
"""package the model repository.
...@@ -752,7 +895,7 @@ def pack_model_repository(workspace_path: str):
def main(model_name: str,
model_path: str,
model_format: str = None,
tokenizer_path: str = None,
dst_path: str = './workspace',
tp: int = 1,
...@@ -777,6 +920,9 @@ def main(model_name: str,
f"'{model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'
if model_format is None:
model_format = 'qwen' if model_name == 'qwen-7b' else 'hf'
if model_format not in supported_formats:
print(f'the model format "{model_format}" is not supported. '
f'The supported format are: {supported_formats}')
...@@ -803,6 +949,9 @@ def main(model_name: str,
elif model_format == 'awq':
res = deploy_awq(model_name, model_path, tokenizer_path,
triton_models_path, tp, quant_path, group_size)
elif model_format == 'qwen':
res = deploy_qwen(model_name, model_path, tokenizer_path,
triton_models_path, tp)
# update `tensor_para_size` in `triton_models/interactive/config.pbtxt`
with open(osp.join(triton_models_path, 'interactive/config.pbtxt'),
...
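With `model_format` now defaulting to `None`, deploying Qwen-7B only requires the model name; the `qwen` branch above is selected automatically. A hedged usage sketch that calls this module's `main` directly (the weight path is a placeholder; in the repo the same function is normally reached through the deploy CLI, which is an assumption here, not something shown in this diff):

```python
# Hedged usage sketch: deploy Qwen-7B-Chat to the default ./workspace layout.
from lmdeploy.serve.turbomind.deploy import main

# model_format is left unset, so the new branch picks 'qwen' for 'qwen-7b'.
main(model_name='qwen-7b',
     model_path='/path/to/Qwen-7B-Chat',
     dst_path='./workspace',
     tp=1)
```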
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import Sequence, Union
...@@ -99,6 +100,13 @@ class HuggingFaceTokenizer:
if hasattr(self.model, 'backend_tokenizer'):
self.model.backend_tokenizer.save(backend_tokenizer_file)
if self.model.eos_token_id is None:
generation_config_file = osp.join(model_dir,
'generation_config.json')
with open(generation_config_file, 'r') as f:
cfg = json.load(f)
self.model.eos_token_id = cfg['eos_token_id']
@property
def vocab_size(self):
"""vocabulary size."""
...
...@@ -60,13 +60,6 @@ struct Multihead_attention_params_base {
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;
// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;
// scales
const float* query_weight_output_scale = nullptr;
const float* attention_qk_scale = nullptr;
...@@ -108,10 +101,6 @@ struct Multihead_attention_params_base {
// The slope per head of linear position bias to attention score (H).
const T* linear_bias_slopes = nullptr;
const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
...@@ -123,17 +112,15 @@ struct Multihead_attention_params_base {
template<typename T>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
bool* finished = nullptr;
const int* length_per_sample = nullptr;
T** k_cache_per_sample = nullptr;
T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0;
int num_kv_heads = 0;
int max_position_embeddings = 0;
bool use_dynamic_ntk = false;
bool use_logn_attn = false;
};
template<class T>
...
...@@ -41,15 +41,12 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
// constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
FT_CHECK(params.cache_indir == nullptr);
const int tlength = params.timestep;
if (params.int8_mode == 4) {
if (tlength < 32) {
...
...@@ -18,8 +18,6 @@
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/turbomind/macro.h"
// #include "src/turbomind/utils/cuda_bf16_wrapper.h"
// #include "src/turbomind/utils/cuda_fp8_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include <assert.h>
#include <float.h>
...@@ -592,42 +590,6 @@ __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
return b;
}
#ifdef ENABLE_FP8
// fp8_t
template<>
__inline__ __device__ float vec_conversion<float, __nv_fp8_e4m3>(const __nv_fp8_e4m3& a)
{
return float(a);
}
template<>
__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a)
{
return __nv_fp8_e4m3(a);
}
// fp8_2_t
template<>
__inline__ __device__ float2 vec_conversion<float2, fp8_2_t>(const fp8_2_t& a)
{
return float2(a);
}
template<>
__inline__ __device__ fp8_2_t vec_conversion<fp8_2_t, float2>(const float2& a)
{
return fp8_2_t(a);
}
// fp8_4_t
template<>
__inline__ __device__ float4 vec_conversion<float4, fp8_4_t>(const fp8_4_t& a)
{
return float4(a);
}
template<>
__inline__ __device__ fp8_4_t vec_conversion<fp8_4_t, float4>(const float4& a)
{
return fp8_4_t(a);
}
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_KEY, typename K_vec, int N>
...@@ -868,19 +830,6 @@ inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_FP8
inline __device__ void convert_from_float(fp8_4_t& dst, float4 src)
{
dst = fp8_4_t(src);
}
inline __device__ void convert_from_float(fp8_2_t& dst, float2 src)
{
dst = fp8_2_t(src);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float2& dst, float2 src)
{
dst = src;
...@@ -1365,8 +1314,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The thread in the block.
const int tidx = threadIdx.x;
constexpr bool handle_kv = true;
// While doing the product Q*K^T for the different keys we track the max.
float qk_max = -FLT_MAX;
...@@ -1422,28 +1369,36 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
Qk_vec_k k_bias;
zero(k_bias);
k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset])) :
k_bias;
// Computes the Q/K values with bias.
q = add(q, q_bias);
k = add(k, k_bias);
float rotary_emb_base = 10000.f;
if (params.use_dynamic_ntk) {
// +1 because of `length_per_sample == context_length - 1`
rotary_emb_base = rotary_embedding_get_base(params.length_per_sample[bi] + 1,
params.max_position_embeddings,
params.rotary_embedding_dim,
rotary_emb_base);
}
// Padded len
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_emb_base, params.timestep - padd_len);
}
if (params.use_logn_attn) {
T log_n_scaling;
// +1 because of `length_per_sample == context_length - 1`
convert_from_float(log_n_scaling,
logn_attn_get_scaling(params.length_per_sample[bi] + 1, params.max_position_embeddings));
q = mul<Qk_vec_k, T, Qk_vec_k>(log_n_scaling, q);
}
if (!is_masked) {
...@@ -1462,47 +1417,23 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
if (group_leader) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;
if (!QUANT_POLICY) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
else if (QUANT_POLICY == 4) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale, k_zp);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
}
}
}
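After this change the per-sample K cache is addressed as a plain `[kv_head, timestep, head_dim]` row-major buffer (the interleaved layout selected by the removed `k_cache_interleaved` flag is gone). A hedged sketch of the offset arithmetic above, with made-up sizes:

```python
# Illustrative check of the new K-cache addressing; sizes and indices are arbitrary.
import numpy as np

memory_max_len, Dh, num_kv_heads = 2048, 128, 8   # params.memory_max_len, Dh, kv heads
kv_cache_per_sample_offset = 0                    # offset of this layer in the sample buffer
kvhi, t, d = 3, 17, 40                            # kv head, circular timestep, element index

offset = kv_cache_per_sample_offset + kvhi * memory_max_len * Dh + t * Dh + d
expected = np.ravel_multi_index((kvhi, t, d), (num_kv_heads, memory_max_len, Dh))
assert offset == expected
```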
...@@ -1584,33 +1515,21 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int8_t* k_cache_batch_int8 = nullptr;
if (!QUANT_POLICY) {
k_cache_batch =
params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
}
else if (QUANT_POLICY == 4) {
// convert k_cache_per_sample to int8
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
}
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
// prefix prompt length if has
const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
// const int* beam_indices = HAS_BEAMS ? &params.cache_indir[bi_seq_len_offset] : nullptr;
for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
const int ti_circ = ti % params.memory_max_len;
...@@ -1622,8 +1541,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ti_circ * Dh / QK_ELTS_IN_16B + ii;
// if( ti < params.timestep ) {
const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
if (ti < tlength) {
...@@ -1632,9 +1550,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
else {
int beam_offset = 0;
if (HAS_BEAMS) {
beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
}
if (!QUANT_POLICY) {
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
...@@ -1771,24 +1686,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int8_t* v_cache_batch_int8 = nullptr;
if (!QUANT_POLICY) {
v_cache =
params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
v_cache_batch = v_cache;
}
else if (QUANT_POLICY == 4) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
v_cache_batch_int8 = v_cache_int8;
}
...@@ -1819,19 +1725,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// the compiler cannot optimize the codes automatically.
const int min_length = min(tlength, params.memory_max_len);
for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) {
// Fetch offset based on cache_indir when beam sampling
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v;
if (!QUANT_POLICY) {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[ti * Dh]));
}
else if (QUANT_POLICY == 4) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[ti * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
...@@ -1867,18 +1768,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
const int ti_circ = ti % params.memory_max_len;
// Fetch offset based on cache_indir when beam sampling
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v;
if (!QUANT_POLICY) {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[ti_circ * Dh]));
}
else if (QUANT_POLICY == 4) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[ti_circ * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
...@@ -1927,11 +1823,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
// Store the V values to cache
if (group_leader) {
// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
if (!QUANT_POLICY) {
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
}
...
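The `tlength_circ`/`ti_circ` indices used in the loads and stores above treat each sample's K/V cache as a ring buffer once the sequence grows past `memory_max_len`. A toy sketch of that indexing (sizes are illustrative, not the kernel's actual layout):

```python
# Ring-buffer write index, mirroring `tlength % memory_max_len` in the kernel.
memory_max_len = 4
cache = [None] * memory_max_len
for tlength in range(7):                      # decode steps
    tlength_circ = tlength % memory_max_len   # same modulo the kernel uses
    cache[tlength_circ] = f'kv@{tlength}'
print(cache)                                  # oldest entries are overwritten in place
# ['kv@4', 'kv@5', 'kv@6', 'kv@3']
```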
...@@ -43,35 +43,6 @@ struct Float4_ {
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
#endif
#ifdef ENABLE_FP8
using fp8_2_t = __nv_fp8x2_e4m3;
using fp8_4_t = __nv_fp8x4_e4m3;
struct fp8_8_t {
__nv_fp8_e4m3 x;
__nv_fp8_e4m3 y;
__nv_fp8_e4m3 z;
__nv_fp8_e4m3 w;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct num_elems;
template<>
...@@ -108,40 +79,6 @@ struct num_elems<uint4> {
static constexpr int value = 8;
};
#ifdef ENABLE_BF16
template<>
struct num_elems<__nv_bfloat162> {
static constexpr int value = 2;
};
template<>
struct num_elems<bf16_4_t> {
static constexpr int value = 4;
};
template<>
struct num_elems<bf16_8_t> {
static constexpr int value = 8;
};
#endif
#ifdef ENABLE_FP8
template<>
struct num_elems<__nv_fp8_e4m3> {
static constexpr int value = 1;
};
template<>
struct num_elems<fp8_2_t> {
static constexpr int value = 2;
};
template<>
struct num_elems<fp8_4_t> {
static constexpr int value = 4;
};
template<>
struct num_elems<fp8_8_t> {
static constexpr int value = 8;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int N>
...@@ -207,44 +144,6 @@ inline __device__ float4 add(float4 a, float4 b)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return a + b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint16_t add(uint16_t a, uint16_t b)
{
uint16_t c;
...@@ -344,24 +243,6 @@ inline __device__ float add(float a, uint16_t b)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float add(float a, __nv_bfloat16 b)
{
return a + __bfloat162float(b);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_FP8
inline __device__ float add(float a, __nv_fp8_e4m3 b)
{
return a + (float)(b);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 add(uint32_t a, float2 fb)
{
float2 fa = half2_to_float2(a);
...@@ -486,38 +367,6 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
{
float2 fa = bf1622float2(a);
return add(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
{
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
{
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
{
uint32_t d;
...@@ -649,134 +498,6 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(bf162bf162(a), b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
{
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
{
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
{
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
{
return fma(bf162bf162(a), b, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
{
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
{
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B>
...@@ -1018,259 +739,67 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b)
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmul(a, b);
#else
return bf16hmul(a, b);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
float fa = (float)a;
float fb = (float)b;
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, float b)
{
return __bfloat162float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
{
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
{
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float v)
{
return v;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float2 v)
{
return v.x + v.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float4 v)
{
return v.x + v.y + v.z + v.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v)
{
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_4_t v)
{
return sum(v.x) + sum(v.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_8_t v)
{
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint16_t v)
{
return half_to_float(v);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint32_t v)
{
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint2 v)
{
uint32_t c = add(v.x, v.y);
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint4 v)
{
#if 1
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
#else
uint32_t c = add(v.x, v.y);
uint32_t d = add(v.z, v.w);
c = add(c, d);
#endif
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float4_ v)
{
return v.x.x + v.x.y + v.y.x + v.y.y;
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -1322,9 +851,40 @@ inline __device__ void zero(T& dst)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float logn_attn_get_scaling(float seq_len, int max_position_embeddings)
{
if (seq_len <= max_position_embeddings) {
return 1.f;
}
return log2f(seq_len) / log2f(max_position_embeddings);
}
inline __device__ float
rotary_embedding_get_base(float seq_len, int max_position_embeddings, float rot_embed_dim, float base)
{
if (seq_len < max_position_embeddings) {
return base;
}
float ntk_alpha = max(exp2f(ceilf(log2f(seq_len / max_position_embeddings) + 1.f)) - 1.f, 1.f);
base *= powf(ntk_alpha, rot_embed_dim / (rot_embed_dim - 2.f));
return base;
}
// inline __device__ float
// rotary_embedding_get_base(float seq_len, int max_position_embeddings, float rot_embed_dim, float base)
// {
// constexpr float scaling_factor = 1.f;
// if (scaling_factor * seq_len < max_position_embeddings) {
// return base;
// }
// base *= powf((scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1.f),
// rot_embed_dim / (rot_embed_dim - 2.f));
// return base;
// }
inline __device__ float2 rotary_embedding_coefficient(int zid, int rot_embed_dim, float base, float t_step)
{
const float inv_freq = t_step / powf(base, zid / (float)rot_embed_dim);
return {cos(inv_freq), sin(inv_freq)};
}
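For quick host-side checks, a hedged Python mirror of the three device helpers above; the formulas follow the CUDA code one-to-one and the sample numbers at the bottom are arbitrary (128 is a typical rotary dimension, 2048 an example `max_position_embeddings`):

```python
# Reference implementations of the dynamic NTK base and logN scaling helpers.
import math

def logn_attn_get_scaling(seq_len, max_position_embeddings):
    if seq_len <= max_position_embeddings:
        return 1.0
    return math.log2(seq_len) / math.log2(max_position_embeddings)

def rotary_embedding_get_base(seq_len, max_position_embeddings, rot_embed_dim, base):
    if seq_len < max_position_embeddings:
        return base
    ntk_alpha = max(2.0 ** math.ceil(math.log2(seq_len / max_position_embeddings) + 1.0) - 1.0, 1.0)
    return base * ntk_alpha ** (rot_embed_dim / (rot_embed_dim - 2.0))

def rotary_embedding_coefficient(zid, rot_embed_dim, base, t_step):
    inv_freq = t_step / base ** (zid / float(rot_embed_dim))
    return math.cos(inv_freq), math.sin(inv_freq)

for seq_len in (1024, 2048, 4096, 8192):
    base = rotary_embedding_get_base(seq_len, 2048, 128, 10000.0)
    print(seq_len, round(base), round(logn_attn_get_scaling(seq_len, 2048), 4))
```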
...@@ -1362,39 +922,39 @@ inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int r
return;
}
inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, float base, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, float base, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, float base, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
}
inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, float base, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
...@@ -1402,455 +962,90 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step) inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, float base, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
{ {
if (2 * tid >= rot_embed_dim) { if (2 * tid >= rot_embed_dim) {
return; return;
} }
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step);
q = rotary_embedding_transform(q, coef); q = rotary_embedding_transform(q, coef);
} }
inline __device__ void inline __device__ void
apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step) apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, float base, int t_step)
{ {
if (2 * tid >= rot_embed_dim) { if (2 * tid >= rot_embed_dim) {
return; return;
} }
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step); const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, t_step);
q = rotary_embedding_transform(q, coef); q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef); k = rotary_embedding_transform(k, coef);
} }
inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step) inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, float base, int t_step)
{ {
if (4 * tid >= rot_embed_dim) { if (4 * tid >= rot_embed_dim) {
return; return;
} }
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0); q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1); q.y = rotary_embedding_transform(q.y, coef1);
} }
inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step) inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, float base, int t_step)
{ {
if (4 * tid >= rot_embed_dim) { if (4 * tid >= rot_embed_dim) {
return; return;
} }
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step); const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0); q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0); k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step); const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1); q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1); k.y = rotary_embedding_transform(k.y, coef1);
} }
inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step) inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, float base, int t_step)
{ {
if (8 * tid >= rot_embed_dim) { if (8 * tid >= rot_embed_dim) {
return; return;
} }
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0); q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1); q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, t_step);
q.z = rotary_embedding_transform(q.z, coef2); q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, t_step);
q.w = rotary_embedding_transform(q.w, coef3); q.w = rotary_embedding_transform(q.w, coef3);
} }
inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step) inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, float base, int t_step)
{ {
if (8 * tid >= rot_embed_dim) { if (8 * tid >= rot_embed_dim) {
return; return;
} }
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step); const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, t_step);
q.x = rotary_embedding_transform(q.x, coef0); q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0); k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step); const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, t_step);
q.y = rotary_embedding_transform(q.y, coef1); q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1); k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step); const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, t_step);
q.z = rotary_embedding_transform(q.z, coef2); q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2); k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step); const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, t_step);
q.w = rotary_embedding_transform(q.w, coef3); q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3); k.w = rotary_embedding_transform(k.w, coef3);
} }
#endif // ENABLE_BF16
template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
template<>
__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = smem[transpose_idx];
tmp.u16[1] = smem[smem_pitch + transpose_idx];
vec = tmp.u32;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp_1, tmp_2;
tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
union {
uint2 u32x2;
uint16_t u16[4];
} tmp_3;
tmp_3.u16[0] = tmp_1.u16[0];
tmp_3.u16[1] = tmp_2.u16[0];
tmp_3.u16[2] = tmp_1.u16[1];
tmp_3.u16[3] = tmp_2.u16[1];
vec = tmp_3.u32x2;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
uint16_t u16[4];
} tmp_1, tmp_2;
tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
union {
uint4 u32x4;
uint16_t u16[8];
} tmp_3;
tmp_3.u16[0] = tmp_1.u16[0];
tmp_3.u16[1] = tmp_2.u16[0];
tmp_3.u16[2] = tmp_1.u16[1];
tmp_3.u16[3] = tmp_2.u16[1];
tmp_3.u16[4] = tmp_1.u16[2];
tmp_3.u16[5] = tmp_2.u16[2];
tmp_3.u16[6] = tmp_1.u16[3];
tmp_3.u16[7] = tmp_2.u16[3];
vec = tmp_3.u32x4;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
__nv_bfloat16 bf16[2];
} tmp_1, tmp_2;
tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
}
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
__nv_bfloat16 bf16[4];
} tmp_1, tmp_2;
tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif // ENABLE_BF16
template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.z = smem[transpose_idx + 1];
vec.y = smem[smem_pitch + transpose_idx];
vec.w = smem[smem_pitch + transpose_idx + 1];
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u16[0] = smem[transpose_idx];
tmp.u16[1] = smem[smem_pitch + transpose_idx];
vec = tmp.u32;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.y = smem[smem_pitch + transpose_idx];
}
#endif
template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.y = smem[smem_pitch + transpose_idx];
}
template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
{
return;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
return;
}
#endif
#ifdef ENABLE_FP8
template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, __nv_fp8_e4m3* smem, int transpose_idx, int smem_pitch)
{
// TODO
printf("[ERROR] still no have implementation for vec_from_smem_transpose under __nv_fp8_e4m3 \n");
}
#endif // ENABLE_FP8
template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
uint16_t u16[4];
} tmp_1, tmp_2;
union {
uint4 u32x4;
uint16_t u16[8];
} tmp_3;
tmp_3.u32x4 = vec;
tmp_1.u16[0] = tmp_3.u16[0];
tmp_2.u16[0] = tmp_3.u16[1];
tmp_1.u16[1] = tmp_3.u16[2];
tmp_2.u16[1] = tmp_3.u16[3];
tmp_1.u16[2] = tmp_3.u16[4];
tmp_2.u16[2] = tmp_3.u16[5];
tmp_1.u16[3] = tmp_3.u16[6];
tmp_2.u16[3] = tmp_3.u16[7];
*reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
*reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp_1, tmp_2;
union {
uint2 u32x2;
uint16_t u16[4];
} tmp_3;
tmp_3.u32x2 = vec;
tmp_1.u16[0] = tmp_3.u16[0];
tmp_2.u16[0] = tmp_3.u16[1];
tmp_1.u16[1] = tmp_3.u16[2];
tmp_2.u16[1] = tmp_3.u16[3];
*reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
*reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = vec;
smem[transpose_idx] = tmp.u16[0];
smem[smem_pitch + transpose_idx] = tmp.u16[1];
}
template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[transpose_idx + 1] = vec.z;
smem[smem_pitch + transpose_idx] = vec.y;
smem[smem_pitch + transpose_idx + 1] = vec.w;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u32 = vec;
smem[transpose_idx] = tmp.u16[0];
smem[smem_pitch + transpose_idx] = tmp.u16[1];
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
#endif
template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
#ifdef ENABLE_FP8
template<>
__device__ __inline__ void
write_smem_transpose(const float4& vec, __nv_fp8_e4m3* smem, int transpose_idx, int smem_pitch)
{
printf("[ERROR] still no have implementation for vec_from_smem_transpose under __nv_fp8_e4m3 \n");
}
#endif // ENABLE_FP8
} // namespace mmha } // namespace mmha
...@@ -29,230 +29,6 @@ __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int d ...@@ -29,230 +29,6 @@ __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int d
return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4; return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
} }
template<typename T>
__global__ void addQKVBiasIA3Transpose(T* q_out,
T* k_out,
T* v_out,
const T* __restrict q_in,
const T* __restrict bias_q,
const T* __restrict k_in,
const T* __restrict bias_k,
const T* __restrict v_in,
const T* __restrict bias_v,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head)
{
const int n = head_num * size_per_head;
const int batch_id = blockIdx.x;
const int word_id = blockIdx.y;
const int row_id = batch_id * seq_len + word_id;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) {
const int head_id = col_id / size_per_head;
const int size_id = col_id % size_per_head;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head
+ word_id * size_per_head + size_id;
const int src_id = row_id * n + col_id;
T q = ldg(&q_in[src_id]);
q_out[target_id] = add(q, ldg(&bias_q[col_id]));
T k = add(ldg(&k_in[src_id]), ldg(&bias_k[col_id]));
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + col_id];
}
k_out[target_id] = k;
T v = add(ldg(&v_in[src_id]), ldg(&bias_v[col_id]));
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + col_id];
}
v_out[target_id] = v;
}
}
template<typename T>
__global__ void QKVIA3Transpose(T* q_out,
T* k_out,
T* v_out,
const T* __restrict q_in,
const T* __restrict k_in,
const T* __restrict v_in,
const int* ia3_tasks,
const T* __restrict ia3_key_weights,
const T* __restrict ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head)
{
const int n = head_num * size_per_head;
const int batch_id = blockIdx.x;
const int word_id = blockIdx.y;
const int row_id = batch_id * seq_len + word_id;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) {
const int head_id = col_id / size_per_head;
const int size_id = col_id % size_per_head;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head
+ word_id * size_per_head + size_id;
const int src_id = row_id * n + col_id;
q_out[target_id] = ldg(&q_in[src_id]);
T k = ldg(&k_in[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + col_id];
}
k_out[target_id] = k;
T v = ldg(&v_in[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + col_id];
}
v_out[target_id] = v;
}
}
template<typename T>
void invokeAddQKVBiasIA3Transpose(T* q_buf,
T* k_buf,
T* v_buf,
T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream)
{
const int k = head_num * size_per_head;
dim3 grid(batch_size, seq_len);
bool is_add_bias = bias_Q != nullptr;
if (sizeof(T) == 4 || k % 2 != 0) {
dim3 block(min(k, 512));
if (is_add_bias) {
addQKVBiasIA3Transpose<T><<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
Q,
bias_Q,
K,
bias_K,
V,
bias_V,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head);
}
else {
QKVIA3Transpose<T><<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
Q,
K,
V,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head);
}
sync_check_cuda_error();
}
else {
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
dim3 block(min(k / 2, 512));
if (is_add_bias) {
addQKVBiasIA3Transpose<T2><<<grid, block, 0, stream>>>((T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
(const T2*)Q,
(const T2*)bias_Q,
(const T2*)K,
(const T2*)bias_K,
(const T2*)V,
(const T2*)bias_V,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2);
}
else {
QKVIA3Transpose<T2><<<grid, block, 0, stream>>>((T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
(const T2*)Q,
(const T2*)K,
(const T2*)V,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2);
}
sync_check_cuda_error();
}
}
#define INSTANTIATEADDQKVBIASIA3TRANSPOSE(T) \
template void invokeAddQKVBiasIA3Transpose(T* q_buf, \
T* k_buf, \
T* v_buf, \
T* Q, \
const T* bias_Q, \
T* K, \
const T* bias_K, \
T* V, \
const T* bias_V, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int* ia3_tasks, \
const T* ia3_key_weights, \
const T* ia3_value_weights, \
cudaStream_t stream)
INSTANTIATEADDQKVBIASIA3TRANSPOSE(float);
INSTANTIATEADDQKVBIASIA3TRANSPOSE(half);
#ifdef ENABLE_BF16
INSTANTIATEADDQKVBIASIA3TRANSPOSE(__nv_bfloat16);
#endif
#undef INSTANTIATEADDQKVBIASTRANSPOSE
template<typename T, typename T_IN, int ITEMS_PER_THREAD> template<typename T, typename T_IN, int ITEMS_PER_THREAD>
__global__ void softmax_kernel(T* attn_score, __global__ void softmax_kernel(T* attn_score,
const T_IN* qk, const T_IN* qk,
...@@ -882,259 +658,6 @@ INSTANTIATETRANSPOSEQKV(__nv_bfloat16); ...@@ -882,259 +658,6 @@ INSTANTIATETRANSPOSEQKV(__nv_bfloat16);
#endif #endif
#undef INSTANTIATETRANSPOSEQKV #undef INSTANTIATETRANSPOSEQKV
template<typename T>
__global__ void add_QKV_bias_rebuild_padding_ia3(const T* Q,
const T* bias_Q,
const T* K,
const T* bias_K,
const T* V,
const T* bias_V,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset)
{
const int bid = blockIdx.x;
const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int n = head_num * size_per_head;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
const int tgt_head_id = idx / size_per_head;
const int tgt_hidden_id = idx % size_per_head;
const int src_id = bid * n + idx;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head
+ tgt_seq_id * size_per_head + tgt_hidden_id;
q_buf_[tgt_id] = add(ldg(&Q[src_id]), ldg(&bias_Q[idx]));
T k = ldg(&K[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + idx];
}
k_buf_[tgt_id] = add(k, ldg(&bias_K[idx]));
T v = ldg(&V[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + idx];
}
v_buf_[tgt_id] = add(v, ldg(&bias_V[idx]));
}
}
template<typename T>
__global__ void rebuild_padding_ia3(const T* Q,
const T* K,
const T* V,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset)
{
const int bid = blockIdx.x;
const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int n = head_num * size_per_head;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
const int tgt_head_id = idx / size_per_head;
const int tgt_hidden_id = idx % size_per_head;
const int src_id = bid * n + idx;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head
+ tgt_seq_id * size_per_head + tgt_hidden_id;
q_buf_[tgt_id] = ldg(&Q[src_id]);
T k = ldg(&K[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + idx];
}
k_buf_[tgt_id] = k;
T v = ldg(&V[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + idx];
}
v_buf_[tgt_id] = v;
}
}
template<typename T>
void invokeAddQKVBiasIA3RebuildPadding(T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
T* q_buf,
T* k_buf,
T* v_buf,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int valid_word_num,
const int* mask_offset,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream)
{
#ifdef ENABLE_BF16
bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (size_per_head % 2 == 0);
#else
bool is_half2 = (std::is_same<T, half>::value) && (size_per_head % 2 == 0);
#endif
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
int block_size = head_num * size_per_head;
if (is_half2) {
while (block_size > 512) {
if (block_size % 2 == 0) {
block_size /= 2;
}
else {
is_half2 = false;
block_size = std::min(block_size, 512);
break;
}
}
}
else {
block_size = std::min(block_size, 512);
}
if (bias_Q == nullptr && bias_K == nullptr && bias_V == nullptr) {
if (is_half2) {
rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>((T2*)Q,
(T2*)K,
(T2*)V,
(T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2,
mask_offset);
}
else {
rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>(Q,
K,
V,
q_buf,
k_buf,
v_buf,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head,
mask_offset);
}
}
else if (bias_Q != nullptr && bias_K != nullptr && bias_V != nullptr) {
if (is_half2) {
add_QKV_bias_rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>((T2*)Q,
(const T2*)bias_Q,
(T2*)K,
(const T2*)bias_K,
(T2*)V,
(const T2*)bias_V,
(T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2,
mask_offset);
}
else {
add_QKV_bias_rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>(Q,
bias_Q,
K,
bias_K,
V,
bias_V,
q_buf,
k_buf,
v_buf,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head,
mask_offset);
}
}
else {
FT_CHECK(false);
}
}
#define INSTANTIATEADDQKVBIASIA3REBUILDPADDING(T) \
template void invokeAddQKVBiasIA3RebuildPadding(T* Q, \
const T* bias_Q, \
T* K, \
const T* bias_K, \
T* V, \
const T* bias_V, \
T* q_buf, \
T* k_buf, \
T* v_buf, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int valid_word_num, \
const int* mask_offset, \
const int* ia3_tasks, \
const T* ia3_key_weights, \
const T* ia3_value_weights, \
cudaStream_t stream)
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(float);
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(half);
#ifdef ENABLE_BF16
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(__nv_bfloat16);
#endif
#undef INSTANTIATEADDQKVBIASREBUILDPADDING
template<typename T> template<typename T>
__global__ void transpose_remove_padding(const T* src, __global__ void transpose_remove_padding(const T* src,
T* dst, T* dst,
...@@ -1326,38 +849,35 @@ struct Vec_t<__nv_bfloat16> { ...@@ -1326,38 +849,35 @@ struct Vec_t<__nv_bfloat16> {
/// TODO: support batch step offset /// TODO: support batch step offset
template<typename T, bool PREFIX_PROMPT> template<typename T, bool PREFIX_PROMPT>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
T* k_buf, T* k_buf,
T* v_buf, T* v_buf,
PrefixPromptBatchWeightsParam<T> param, T* QKV,
T* QKV,
const T* __restrict qkv_bias, const T* __restrict qkv_bias,
const int* padding_offset, const int* padding_offset,
const int* history_length, const int* history_length,
const int batch_size, const int* input_length,
const int seq_len, int batch_size,
const int head_num, int seq_len,
const int kv_head_num, int head_num,
const int size_per_head, int kv_head_num,
const int rotary_embedding_dim, int size_per_head,
const bool neox_rotary_style) int rotary_embedding_dim,
int max_position_embeddings,
bool use_dynamic_ntk,
bool use_logn_attn)
{ {
// This kernel adds bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and // This kernel adds bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and
// QKV is split into 3 buffers q, k, v, which are transposed to [batch_size, head_num, seq_len, size_per_head]. // QKV is split into 3 buffers q, k, v, which are transposed to [batch_size, head_num, seq_len, size_per_head].
// For q and k, also apply the rotary embedding. // For q and k, also apply the rotary embedding.
// When we pass a prefix prompt, this kernel also concatenates the prefix prompt and key/value along
// seq_len dimension like [prompt, key/value].
// So, the final shape of q is same ([batch_size, head_num, seq_len, size_per_head]), but
// the shapes of key and values become [batch_size, head_num, max_prefix_prompt_length + seq_len, size_per_head].
// NOTE: QKV src shape (batch_size, seq_len, 3, head_num, size_per_head) // NOTE: QKV src shape (batch_size, seq_len, 3, head_num, size_per_head)
// QKV dst shape (3, batch_size, head_num, seq_len, size_per_head) // QKV dst shape (3, batch_size, head_num, seq_len, size_per_head)
extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type
constexpr int vec_size = Vec_t<T>::size; constexpr int vec_size = Vec_t<T>::size;
using Vec_t = typename Vec_t<T>::Type; using Vec_t = typename Vec_t<T>::Type;
const int token_idx = blockIdx.x - batch_size * param.max_prefix_prompt_length; const int token_idx = blockIdx.x;
const int token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 0 : padding_offset[token_idx]; const int token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 0 : padding_offset[token_idx];
const int tgt_token_idx = token_idx + token_padding_offset; const int tgt_token_idx = token_idx + token_padding_offset;
...@@ -1367,49 +887,11 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1367,49 +887,11 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
const int head_idx = blockIdx.y; const int head_idx = blockIdx.y;
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
const int total_seq_len = param.max_prefix_prompt_length + seq_len; const int total_seq_len = seq_len;
const bool is_masked = tidx * vec_size >= size_per_head; const bool is_masked = tidx * vec_size >= size_per_head;
// NOTE: blockIdx.x < batch_size * param.max_prefix_prompt_length really handles prefix prompts
if (PREFIX_PROMPT && token_idx < 0) {
const int prompt_batch_idx = blockIdx.x / param.max_prefix_prompt_length;
const int prompt_seq_idx = blockIdx.x % param.max_prefix_prompt_length;
const int prompt_length = param.d_prefix_prompt_lengths[prompt_batch_idx];
if (prompt_seq_idx < prompt_length) {
const int dest_kv_idx = prompt_batch_idx * size_per_head * total_seq_len * head_num
+ head_idx * size_per_head * total_seq_len + prompt_seq_idx * size_per_head
+ tidx * vec_size;
const int prefix_kv_idx =
size_per_head * prompt_length * head_idx + size_per_head * prompt_seq_idx + tidx * vec_size;
const T* prefix_prompt_k = param.d_prefix_prompt_batch[prompt_batch_idx]
+ param.prefix_prompt_layer_offset_per_seq * prompt_length;
const T* prefix_prompt_v = prefix_prompt_k + prompt_length * head_num * size_per_head;
if (!is_masked) {
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) =
*reinterpret_cast<const Vec_t*>(&prefix_prompt_k[prefix_kv_idx]);
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) =
*reinterpret_cast<const Vec_t*>(&prefix_prompt_v[prefix_kv_idx]);
}
}
return;
}
const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
// const int n = head_num * size_per_head;
// the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len) const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
// and Q [0..seq_len)
// Note: if !PREFIX_PROMPT, max_pp_len = 0, so it's no-op
const int dst_kv_seq_idx = seq_idx + prefix_prompt_length;
// NOTE: q has seq len excluding prefix prompt
// src QKV: [batch, time, 3, head, hidden]
// const int src_q_idx = token_idx * 3 * n + hidden_idx;
// const int src_k_idx = token_idx * 3 * n + hidden_idx + n;
// const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n;
const int q_kv_head_num = head_num + 2 * kv_head_num; const int q_kv_head_num = head_num + 2 * kv_head_num;
...@@ -1445,48 +927,28 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1445,48 +927,28 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
} }
} }
const int t_offset = history_length ? history_length[batch_idx] : 0; const int history_len = history_length[batch_idx];
const int context_len = history_len + input_length[batch_idx];
const int timestep = history_len + seq_idx;
if (!neox_rotary_style) { float rotary_emb_base = 10000.f;
// TODO: unused computation on k if GQA is used if (use_dynamic_ntk) {
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx + t_offset); rotary_emb_base = mmha::rotary_embedding_get_base(
context_len, max_position_embeddings, rotary_embedding_dim, rotary_emb_base);
} }
else {
const bool do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
T* q_smem = reinterpret_cast<T*>(smem_);
T* k_smem = q_smem + rotary_embedding_dim;
const int half_rotary_dim = rotary_embedding_dim / 2;
const int half_idx = (tidx * vec_size) / half_rotary_dim;
const int intra_half_idx = (tidx * vec_size) % half_rotary_dim;
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts?
if (do_rotary) {
*reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
*reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
}
__syncthreads(); // TODO: unused computation on k if GQA is used
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_emb_base, timestep);
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
constexpr int tidx_factor = vec_size / 2;
if (do_rotary) {
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, rotary_embedding_dim, dst_kv_seq_idx + t_offset);
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); if (use_logn_attn) {
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); // +1 converts the zero-based timestep to the context length at this position
float logn_scaling = mmha::logn_attn_get_scaling(timestep + 1, max_position_embeddings);
if constexpr (std::is_same_v<T, float>) {
q = mmha::mul<Vec_t, float, Vec_t>(logn_scaling, q);
} }
else if constexpr (std::is_same_v<T, half>) {
__syncthreads(); half tmp = __float2half(logn_scaling);
q = mmha::mul<Vec_t, uint16_t, Vec_t>((uint16_t&)tmp, q);
if (do_rotary) {
q = *reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx);
k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
} }
} }
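// Putting the new pieces together: each token derives timestep = history_len + seq_idx
// and context_len = history_len + input_length[batch_idx], optionally rescales the
// RoPE base from context_len (dynamic NTK), rotates q/k at timestep, and optionally
// multiplies q by the logN factor. A host-side sketch of that per-token flow, assuming
// the *_ref helpers from the earlier sketches are in scope (all names illustrative;
// the interleaved channel pairing matches the non-NeoX rotary path):
void rope_and_scale_token_ref(float* q, float* k, int size_per_head, int rot_embed_dim,
                              int history_len, int input_len, int seq_idx,
                              int max_position_embeddings,
                              bool use_dynamic_ntk, bool use_logn_attn)
{
    const int context_len = history_len + input_len;  // full prompt incl. cached history
    const int timestep    = history_len + seq_idx;    // absolute position of this token

    float base = 10000.f;
    if (use_dynamic_ntk) {
        base = ntk_scaled_base_ref(context_len, max_position_embeddings, rot_embed_dim, base);
    }

    // rotate adjacent (even, odd) channel pairs of q and k at `timestep`
    for (int zid = 0; zid + 1 < rot_embed_dim; zid += 2) {
        const auto [c, s] = rotary_coef_ref(zid, rot_embed_dim, base, timestep);
        const float q0 = q[zid], q1 = q[zid + 1];
        q[zid]     = q0 * c - q1 * s;
        q[zid + 1] = q0 * s + q1 * c;
        const float k0 = k[zid], k1 = k[zid + 1];
        k[zid]     = k0 * c - k1 * s;
        k[zid + 1] = k0 * s + k1 * c;
    }

    if (use_logn_attn) {
        // only q is scaled; +1 turns the zero-based timestep into a length
        const float scaling = logn_attn_scaling_ref(timestep + 1, max_position_embeddings);
        for (int i = 0; i < size_per_head; ++i) {
            q[i] *= scaling;
        }
    }
}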
...@@ -1502,8 +964,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1502,8 +964,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
+ seq_idx * size_per_head + tidx * vec_size; + seq_idx * size_per_head + tidx * vec_size;
const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num
+ head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head + head_idx * size_per_head * total_seq_len + seq_idx * size_per_head + tidx * vec_size;
+ tidx * vec_size;
if (!is_masked) { if (!is_masked) {
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q; *reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
...@@ -1518,75 +979,70 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1518,75 +979,70 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
add_fusedQKV_bias_transpose_kernel<T, PREFIX_PROMPT><<<grid, block, smem_size, stream>>>(q_buf, \ add_fusedQKV_bias_transpose_kernel<T, PREFIX_PROMPT><<<grid, block, smem_size, stream>>>(q_buf, \
k_buf, \ k_buf, \
v_buf, \ v_buf, \
param, \
QKV, \ QKV, \
qkv_bias, \ qkv_bias, \
padding_offset, \ padding_offset, \
history_length, \ history_length, \
input_length, \
batch_size, \ batch_size, \
seq_len, \ seq_len, \
head_num, \ head_num, \
kv_head_num, \ kv_head_num, \
size_per_head, \ size_per_head, \
rotary_embedding_dim, \ rotary_embedding_dim, \
neox_rotary_style); max_position_embeddings, \
use_dynamic_ntk, \
use_logn_attn);
template<typename T> template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf, void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf, T* k_buf,
T* v_buf, T* v_buf,
PrefixPromptBatchWeightsParam<T> param, T* QKV,
T* QKV, const T* qkv_bias,
const T* qkv_bias, const int* padding_offset,
const int* padding_offset, const int* history_length,
const int* history_length, const int* input_length,
const int batch_size, const int batch_size,
const int seq_len, const int seq_len,
const int token_num, const int token_num,
const int head_num, const int head_num,
const int kv_head_num, const int kv_head_num,
const int size_per_head, const int size_per_head,
const int rotary_embedding_dim, const int rotary_embedding_dim,
const int neox_rotary_style, int max_position_embeddings,
const float* scale, bool use_dynamic_ntk,
const int int8_mode, bool use_logn_attn,
cudaStream_t stream) cudaStream_t stream)
{ {
FT_CHECK(rotary_embedding_dim); FT_CHECK(rotary_embedding_dim);
FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec)
// To implement rotary embeddings, each thread processes two QKV elems: // To implement rotary embeddings, each thread processes two QKV elems:
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32); dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num); dim3 grid(token_num, head_num);
size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0; size_t smem_size = 0;
// NOTE: add offset for rotary embedding FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
if (param.max_prefix_prompt_length == 0) {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
}
else {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true);
}
} }
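// With the prefix-prompt and NeoX paths gone, the launch configuration is simple:
// one block per (token, head) pair, each thread covering Vec_t<T>::size elements,
// and no dynamic shared memory. A sketch of the launch-shape arithmetic; the struct
// and function names are illustrative.
struct LaunchShape {
    unsigned grid_x;   // one block per (unpadded) token
    unsigned grid_y;   // one block per attention head
    unsigned block_x;  // threads per block, rounded up to a whole warp
};

inline LaunchShape fused_qkv_launch_shape(int token_num, int head_num,
                                          int size_per_head, int vec_size /* Vec_t<T>::size */)
{
    return LaunchShape{static_cast<unsigned>(token_num),
                       static_cast<unsigned>(head_num),
                       static_cast<unsigned>((size_per_head / vec_size + 31) / 32 * 32)};
}

// Example: size_per_head = 128 with vec_size = 2 gives 64 threads per block and a
// grid of (token_num, head_num); smem_size stays 0 now that the shared-memory
// transpose used by the NeoX rotary style has been removed.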
#define INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(T) \ #define INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(T) \
template void invokeAddFusedQKVBiasTranspose(T* q_buf, \ template void invokeAddFusedQKVBiasTranspose(T* q_buf, \
T* k_buf, \ T* k_buf, \
T* v_buf, \ T* v_buf, \
PrefixPromptBatchWeightsParam<T> param, \ T* QKV, \
T* QKV, \ const T* qkv_bias, \
const T* qkv_bias, \ const int* padding_offset, \
const int* padding_offset, \ const int* history_length, \
const int* history_length, \ const int* input_length, \
const int batch_size, \ const int batch_size, \
const int seq_len, \ const int seq_len, \
const int token_num, \ const int token_num, \
const int head_num, \ const int head_num, \
const int kv_head_num, \ const int kv_head_num, \
const int size_per_head, \ const int size_per_head, \
const int rotary_embedding_dim, \ const int rotary_embedding_dim, \
const int neox_rotary_style, \ int max_position_embeddings, \
const float* scale, \ bool use_dynamic_ntk, \
const int int8_mode, \ bool use_logn_attn, \
cudaStream_t stream) cudaStream_t stream)
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(float); INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(float);
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(half); INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(half);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
......
...@@ -19,25 +19,6 @@ ...@@ -19,25 +19,6 @@
namespace turbomind { namespace turbomind {
template<typename T>
void invokeAddQKVBiasIA3Transpose(T* q_buf,
T* k_buf,
T* v_buf,
T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream);
template<typename T, typename T_IN> template<typename T, typename T_IN>
struct MaskedSoftmaxParam { struct MaskedSoftmaxParam {
// Common parameters. // Common parameters.
...@@ -69,27 +50,6 @@ void invokeTransposeQKV(T* dst, ...@@ -69,27 +50,6 @@ void invokeTransposeQKV(T* dst,
const int int8_mode, const int int8_mode,
cudaStream_t stream); cudaStream_t stream);
template<typename T>
void invokeAddQKVBiasIA3RebuildPadding(T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
T* q_buf,
T* k_buf,
T* v_buf,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int valid_word_num,
const int* mask_offset,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream);
template<typename T> template<typename T>
void invokeTransposeAttentionOutRemovePadding(T* src, void invokeTransposeAttentionOutRemovePadding(T* src,
T* dst, T* dst,
...@@ -103,36 +63,26 @@ void invokeTransposeAttentionOutRemovePadding(T* src, ...@@ -103,36 +63,26 @@ void invokeTransposeAttentionOutRemovePadding(T* src,
const int int8_mode, const int int8_mode,
cudaStream_t stream); cudaStream_t stream);
// Prefix Prompt Parameters
template<typename T> template<typename T>
struct PrefixPromptBatchWeightsParam { void invokeAddFusedQKVBiasTranspose(T* q_buf,
const T** d_prefix_prompt_batch = nullptr; T* k_buf,
const int* d_prefix_prompt_lengths = nullptr; T* v_buf,
const int max_prefix_prompt_length = 0; T* QKV,
// l * 2 * hidden_units_ / tensor_para_.world_size_ const T* qkv_bias,
const size_t prefix_prompt_layer_offset_per_seq = 0; const int* padding_offset,
}; const int* history_length,
const int* input_length,
template<typename T> const int batch_size,
void invokeAddFusedQKVBiasTranspose(T* q_buf, const int seq_len,
T* k_buf, const int token_num,
T* v_buf, const int head_num,
PrefixPromptBatchWeightsParam<T> param, const int kv_head_num,
T* QKV, const int size_per_head,
const T* qkv_bias, const int rotary_embedding_dim,
const int* padding_offset, int max_position_embeddings,
const int* history_length, bool use_dynamic_ntk,
const int batch_size, bool use_logn_attn,
const int seq_len, cudaStream_t stream);
const int token_num,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream);
template<typename T> template<typename T>
void invokeTranspose4d(T* dst, void invokeTranspose4d(T* dst,
......
...@@ -163,21 +163,21 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -163,21 +163,21 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
invokeAddFusedQKVBiasTranspose(q_buf_2_, invokeAddFusedQKVBiasTranspose(q_buf_2_,
k_buf_2_, k_buf_2_,
v_buf_2_, v_buf_2_,
PrefixPromptBatchWeightsParam<T>{},
qkv_buf_, qkv_buf_,
weights->qkv.bias, weights->qkv.bias,
padding_offset, // padding_offset, padding_offset, // padding_offset,
history_length, // used for applying rotary embedding history_length, // used for applying rotary embedding
input_length,
batch_size, batch_size,
max_q_len, // seq_len max_q_len, // seq_len
num_token, // batch_size * seq_len num_token, // batch_size * seq_len
local_head_num_, local_head_num_,
local_kv_head_num_, local_kv_head_num_,
size_per_head_, size_per_head_,
rotary_embedding_dim_, params_.rotray_embedding_dim,
neox_rotary_style_, params_.max_position_embeddings,
nullptr, // query_weight.scale_out params_.use_dynamic_ntk,
0, // int8 mode params_.use_logn_attn,
stream_); stream_);
sync_check_cuda_error(); sync_check_cuda_error();
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h" #include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/nccl_utils.h" #include "src/turbomind/utils/nccl_utils.h"
...@@ -34,26 +35,24 @@ public: ...@@ -34,26 +35,24 @@ public:
void freeBuffer(); void freeBuffer();
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
LlamaContextAttentionLayer(size_t head_num, LlamaContextAttentionLayer(size_t head_num,
size_t kv_head_num, size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t rotary_embedding_dim, LlamaAttentionParams attn_params,
bool neox_rotary_style, NcclParam tensor_para,
NcclParam tensor_para, cudaStream_t stream,
cudaStream_t stream, cublasMMWrapper* cublas_wrapper,
cublasMMWrapper* cublas_wrapper, IAllocator* allocator,
IAllocator* allocator, bool is_free_buffer_after_forward,
bool is_free_buffer_after_forward, bool use_fmha,
bool use_fmha, int quant_policy):
int quant_policy):
head_num_(head_num), head_num_(head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head), hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_), local_head_num_(head_num / tensor_para.world_size_),
local_kv_head_num_(kv_head_num / tensor_para.world_size_), local_kv_head_num_(kv_head_num / tensor_para.world_size_),
head_n_rep_(head_num / kv_head_num), head_n_rep_(head_num / kv_head_num),
rotary_embedding_dim_(rotary_embedding_dim), params_(attn_params),
neox_rotary_style_(neox_rotary_style),
tensor_para_(tensor_para), tensor_para_(tensor_para),
stream_(stream), stream_(stream),
cublas_wrapper_(cublas_wrapper), cublas_wrapper_(cublas_wrapper),
...@@ -99,10 +98,9 @@ private: ...@@ -99,10 +98,9 @@ private:
const size_t local_kv_head_num_; const size_t local_kv_head_num_;
const size_t local_head_num_; const size_t local_head_num_;
const size_t head_n_rep_; const size_t head_n_rep_;
const size_t rotary_embedding_dim_;
const bool is_free_buffer_after_forward_; const bool is_free_buffer_after_forward_;
const bool neox_rotary_style_; const LlamaAttentionParams params_;
const bool use_fmha_; const bool use_fmha_;
const int quant_policy_; const int quant_policy_;
......
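// The per-layer rotary knobs are now bundled into a single LlamaAttentionParams that
// is threaded through the attention and decoder constructors. llama_params.h itself
// is not part of this excerpt; judging only from the members dereferenced above
// (params_.rotray_embedding_dim, params_.max_position_embeddings,
// params_.use_dynamic_ntk, params_.use_logn_attn), its shape is presumably roughly:
struct LlamaAttentionParams {
    int  rotray_embedding_dim;     // spelling kept exactly as referenced at the call site
    int  max_position_embeddings;  // training context window of the model
    bool use_dynamic_ntk;          // enable NTK-aware RoPE base rescaling
    bool use_logn_attn;            // enable logN query scaling
};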
...@@ -61,15 +61,17 @@ void LlamaContextDecoder<T>::freeBuffer() ...@@ -61,15 +61,17 @@ void LlamaContextDecoder<T>::freeBuffer()
} }
template<typename T> template<typename T>
void LlamaContextDecoder<T>::initialize(size_t kv_head_num, bool use_fmha, int quant_policy) void LlamaContextDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
size_t kv_head_num,
bool use_fmha,
int quant_policy)
{ {
h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_, context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
kv_head_num, kv_head_num,
size_per_head_, size_per_head_,
rotary_embedding_dim_, attn_params,
false, // neox_rotary_style
tensor_para_, tensor_para_,
stream_, stream_,
cublas_wrapper_, cublas_wrapper_,
...@@ -124,32 +126,31 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session& ...@@ -124,32 +126,31 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
} }
template<typename T> template<typename T>
LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num, LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
size_t kv_head_num, size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
size_t rotary_embedding_dim, const LlamaAttentionParams& attn_params,
float rmsnorm_eps, float rmsnorm_eps,
NcclParam tensor_para, NcclParam tensor_para,
cudaStream_t stream, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
bool use_fmha, bool use_fmha,
int quant_policy): int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num), head_num_(head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
inter_size_(inter_size), inter_size_(inter_size),
hidden_units_(head_num * size_per_head), hidden_units_(head_num * size_per_head),
num_layer_(num_layer), num_layer_(num_layer),
rotary_embedding_dim_(rotary_embedding_dim),
rmsnorm_eps_(rmsnorm_eps), rmsnorm_eps_(rmsnorm_eps),
tensor_para_(tensor_para), tensor_para_(tensor_para),
data_type_(getTensorType<T>()) data_type_(getTensorType<T>())
{ {
initialize(kv_head_num, use_fmha, quant_policy); initialize(attn_params, kv_head_num, use_fmha, quant_policy);
} }
template<typename T> template<typename T>
......
...@@ -20,14 +20,11 @@ ...@@ -20,14 +20,11 @@
#pragma once #pragma once
// #include "src/turbomind/kernels/add_residual_kernels.h"
// #include "src/turbomind/kernels/layernorm_kernels.h"
#include "src/turbomind/layers/BaseLayer.h" #include "src/turbomind/layers/BaseLayer.h"
// #include "src/turbomind/layers/FfnLayer.h"
// #include "src/turbomind/layers/attention_layers/BaseAttentionLayer.h"
#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h" #include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaFfnLayer.h" #include "src/turbomind/models/llama/LlamaFfnLayer.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cublasMMWrapper.h"
...@@ -43,13 +40,12 @@ protected: ...@@ -43,13 +40,12 @@ protected:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
void freeBuffer() override; void freeBuffer() override;
void initialize(size_t kv_head_num, bool use_fmha, int quant_policy); void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_fmha, int quant_policy);
size_t head_num_; size_t head_num_;
size_t size_per_head_; size_t size_per_head_;
size_t inter_size_; size_t inter_size_;
size_t num_layer_; size_t num_layer_;
size_t rotary_embedding_dim_;
size_t hidden_units_; size_t hidden_units_;
float rmsnorm_eps_; float rmsnorm_eps_;
...@@ -87,20 +83,20 @@ protected: ...@@ -87,20 +83,20 @@ protected:
bool is_final); bool is_final);
public: public:
LlamaContextDecoder(size_t head_num, LlamaContextDecoder(size_t head_num,
size_t kv_head_num, size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
size_t rotary_embedding_dim, const LlamaAttentionParams& attn_params,
float rmsnorm_eps, float rmsnorm_eps,
NcclParam tensor_para, NcclParam tensor_para,
cudaStream_t stream, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
bool use_fmha, bool use_fmha,
int quant_policy); int quant_policy);
~LlamaContextDecoder() override; ~LlamaContextDecoder() override;
......
@@ -23,37 +23,37 @@
 #include "src/turbomind/macro.h"
 #include "src/turbomind/models/llama/llama_decoder_kernels.h"
 #include "src/turbomind/models/llama/llama_kernels.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 namespace turbomind {
 template<typename T>
 LlamaDecoder<T>::LlamaDecoder(size_t head_num,
                               size_t kv_head_num,
                               size_t size_per_head,
                               size_t inter_size,
                               size_t num_layer,
-                              size_t rotary_embedding_dim,
+                              const LlamaAttentionParams& attn_params,
                               float rmsnorm_eps,
                               NcclParam tensor_para,
                               cudaStream_t stream,
                               cublasMMWrapper* cublas_wrapper,
                               IAllocator* allocator,
                               bool is_free_buffer_after_forward,
                               int quant_policy):
     BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
     head_num_(head_num),
     size_per_head_(size_per_head),
     inter_size_(inter_size),
     num_layer_(num_layer),
-    rotary_embedding_dim_(rotary_embedding_dim),
     hidden_units_(head_num * size_per_head),
     rmsnorm_eps_(rmsnorm_eps),
     tensor_para_(tensor_para),
     data_type_(getTensorType<T>())
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    initialize(kv_head_num, quant_policy);
+    initialize(attn_params, kv_head_num, quant_policy);
 }
 template<typename T>
@@ -65,15 +65,14 @@ LlamaDecoder<T>::~LlamaDecoder()
 }
 template<typename T>
-void LlamaDecoder<T>::initialize(size_t kv_head_num, int quant_policy)
+void LlamaDecoder<T>::initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_,
                                                                   kv_head_num,
                                                                   size_per_head_,
-                                                                  rotary_embedding_dim_,
-                                                                  false, // neox_rotary_style
+                                                                  attn_params,
                                                                   tensor_para_,
                                                                   stream_,
                                                                   cublas_wrapper_,
...
@@ -20,10 +20,10 @@
 // https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptDecoder.h
 #include "src/turbomind/layers/BaseLayer.h"
-// #include "src/turbomind/layers/FfnLayer.h"
 #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
 #include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
 #include "src/turbomind/models/llama/LlamaFfnLayer.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/custom_ar_comm.h"
 #include "src/turbomind/utils/nccl_utils.h"
@@ -35,13 +35,12 @@ protected:
     void allocateBuffer() override; // deprecated
     void allocateBuffer(size_t batch_size);
     void freeBuffer() override;
-    void initialize(size_t kv_head_num, int quant_policy);
+    void initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int quant_policy);
     size_t head_num_;
     size_t size_per_head_;
     size_t inter_size_;
     size_t num_layer_;
-    size_t rotary_embedding_dim_;
     size_t hidden_units_;
     float rmsnorm_eps_;
@@ -69,19 +68,19 @@ protected:
     void forwardFfn(const LlamaDecoder::Session& sess, T* ffn_io, size_t layer);
 public:
     LlamaDecoder(size_t head_num,
                  size_t kv_head_num,
                  size_t size_per_head,
                  size_t inter_size,
                  size_t num_layer,
-                 size_t rotary_embedding_dim,
+                 const LlamaAttentionParams& attn_params,
                  float rmsnorm_eps,
                  NcclParam tensor_para,
                  cudaStream_t stream,
                  cublasMMWrapper* cublas_wrapper,
                  IAllocator* allocator,
                  bool is_free_buffer_after_forward,
                  int quant_policy);
     ~LlamaDecoder() override;
...
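For context, a caller (for example the model-level setup code, which is outside this diff) would now fill the bundle once and pass it to the decoder constructors shown above. The snippet below is illustrative only; the values are placeholders and the member names follow the hypothetical sketch given earlier.

```cpp
// Illustrative call site, assuming the LlamaAttentionParams sketch above and
// that head_num, stream, cublas_wrapper, etc. are already in scope.
LlamaAttentionParams attn_params{};
attn_params.rotary_embedding_dim    = 128;     // typically equal to size_per_head
attn_params.rotary_embedding_base   = 10000.f; // standard RoPE base
attn_params.max_position_embeddings = 8192;    // e.g. the Qwen-7B-Chat session length
attn_params.use_dynamic_ntk         = true;    // extend context via NTK-aware RoPE
attn_params.use_logn_attn           = true;    // damp attention for long sequences

auto* decoder = new LlamaDecoder<half>(head_num,
                                       kv_head_num,
                                       size_per_head,
                                       inter_size,
                                       num_layer,
                                       attn_params,  // replaces the old rotary_embedding_dim argument
                                       rmsnorm_eps,
                                       tensor_para,
                                       stream,
                                       cublas_wrapper,
                                       allocator,
                                       is_free_buffer_after_forward,
                                       quant_policy);
```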
@@ -61,6 +61,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
                                                       const int kv_head_num,
                                                       const int size_per_head,
                                                       const int rotary_embedding_dim,
+                                                      const int max_position_embeddings,
+                                                      const bool use_dynamic_ntk,
+                                                      const bool use_logn_attn,
                                                       const int memory_max_len,
                                                       const int* prefix_prompt_lengths,
                                                       const int max_prefix_prompt_length,
@@ -110,13 +113,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
     FT_CHECK(k_cache_per_sample && v_cache_per_sample);
-    params.k_cache = reinterpret_cast<DataType*>(key_cache);
-    params.v_cache = reinterpret_cast<DataType*>(value_cache);
     params.k_cache_per_sample = reinterpret_cast<DataType**>(k_cache_per_sample);
     params.v_cache_per_sample = reinterpret_cast<DataType**>(v_cache_per_sample);
     params.kv_cache_per_sample_offset = kv_cache_per_sample_offset;
-    params.k_cache_interleaved = false;
-    params.cache_indir = cache_indir;
     params.batch_size = inference_batch_size;
     params.beam_width = beam_width;
     params.memory_max_len = memory_max_len;
@@ -128,8 +127,12 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
     params.num_heads = head_num;
     params.num_kv_heads = kv_head_num;
     params.hidden_size_per_head = size_per_head;
     params.rotary_embedding_dim = rotary_embedding_dim;
+    params.max_position_embeddings = max_position_embeddings;
+    params.use_dynamic_ntk = use_dynamic_ntk;
+    params.use_logn_attn = use_logn_attn;
     // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
     params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling);
@@ -146,10 +149,6 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
     }
     params.max_input_length = max_input_len;
-    params.ia3_tasks = ia3_tasks;
-    params.ia3_key_weights = reinterpret_cast<const DataType*>(ia3_key_weights);
-    params.ia3_value_weights = reinterpret_cast<const DataType*>(ia3_value_weights);
     params.int8_mode = int8_mode;
     if (int8_mode & QuantPolicy::kCacheKVInt8) {
@@ -185,8 +184,6 @@ void LlamaDecoderSelfAttentionLayer<T>::freeBuffer()
     if (is_allocate_buffer_) {
         allocator_->free((void**)(&qkv_buf_));
         allocator_->free((void**)(&context_buf_));
-        // allocator_->free((void**)(&k_cache_buf_));
-        // allocator_->free((void**)(&v_cache_buf_));
         is_allocate_buffer_ = false;
     }
 }
@@ -263,7 +260,10 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
                                           local_head_num_,
                                           local_kv_head_num_,
                                           size_per_head_,
-                                          rotary_embedding_dim_,
+                                          params_.rotray_embedding_dim,
+                                          params_.max_position_embeddings,
+                                          params_.use_dynamic_ntk,
+                                          params_.use_logn_attn,
                                           memory_len,
                                           nullptr, // prefix_prompt_lengths
                                           0, // max_prefix_prompt_length
...
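The new `max_position_embeddings`, `use_dynamic_ntk` and `use_logn_attn` parameters are only forwarded here; the fused masked-attention kernel that consumes them is not part of this diff. As a rough reference, the sketch below shows how dynamic NTK-RoPE rescaling and logN attention scaling are commonly computed in the Qwen-style formulation; the exact arithmetic inside the turbomind kernel may differ.

```cpp
#include <cmath>

// Hedged sketch, not the turbomind kernel: dynamic NTK enlarges the RoPE base
// once the current sequence length exceeds the trained context window.
inline float dynamic_ntk_base(float base, int rotary_dim, int seq_len, int max_position_embeddings)
{
    // ntk_alpha = 2^ceil(log2(seq_len / trained_len) + 1) - 1, clamped to >= 1
    float context   = std::log2(static_cast<float>(seq_len) / max_position_embeddings) + 1.f;
    float ntk_alpha = std::fmax(std::exp2(std::ceil(context)) - 1.f, 1.f);
    return base * std::pow(ntk_alpha, rotary_dim / (rotary_dim - 2.f));
}

// Hedged sketch of logN scaling: queries past the trained length are scaled by
// log(position) / log(trained_len) to keep attention entropy roughly constant.
inline float logn_attn_scale(int position, int max_position_embeddings)
{
    return position > max_position_embeddings
               ? std::log(static_cast<float>(position)) / std::log(static_cast<float>(max_position_embeddings))
               : 1.f;
}
```

In the kernel, the first value would feed the rotary frequency computation in place of the fixed base, and the second would multiply the query (or the attention logits) for the corresponding position.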