Commit f07b697b authored by Li Zhang, committed by GitHub

[Feature] Support Llama-2 with GQA (#147)

* add GQA for llama2

* fix model conversion

* fix lint & remove dev log

* update news

* minor

* fix allocation size

* fix split_dim for w_qkv.bias
parent 2a475478
...@@ -13,8 +13,9 @@ ______________________________________________________________________
## News 🎉
- \[2023/07\] TurboMind supports Llama-2 70B with GQA.
- \[2023/07\] TurboMind supports Llama-2 7B/13B.
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports llama2 7b/13b.
______________________________________________________________________
......
...@@ -13,8 +13,9 @@ ______________________________________________________________________
## News 🎉
- \[2023/07\] TurboMind supports Llama-2 70B models with GQA
- \[2023/07\] TurboMind supports Llama-2 7B/13B models
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM
- \[2023/07\] TurboMind supports Llama2 7b/13b models
______________________________________________________________________
......
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
add_executable(llama_triton_example llama_triton_example.cc) add_executable(llama_triton_example llama_triton_example.cc)
target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
nvtx_utils word_list glog) nvtx_utils word_list)
install(TARGETS llama_triton_example DESTINATION ${CMAKE_INSTALL_PREFIX}/bin) install(TARGETS llama_triton_example DESTINATION ${CMAKE_INSTALL_PREFIX}/bin)
...@@ -95,6 +95,7 @@ def tokenizer_info(model_path: str): ...@@ -95,6 +95,7 @@ def tokenizer_info(model_path: str):
def export(model_name: str, def export(model_name: str,
num_layer: int, num_layer: int,
norm_eps: float, norm_eps: float,
kv_head_num: int,
model_params: dict, model_params: dict,
tokenizer_path: str, tokenizer_path: str,
out_dir: str, out_dir: str,
...@@ -133,10 +134,12 @@ def export(model_name: str, ...@@ -133,10 +134,12 @@ def export(model_name: str,
if key == 'w_qkv' and ext == 'bias': if key == 'w_qkv' and ext == 'bias':
attn_bias = True attn_bias = True
copy = False copy = False
if key in ['w1', 'w3', 'w_qkv']: if key in ['w1', 'w3']:
split_dim = -1 split_dim = -1
if key == 'w1': if key == 'w1':
inter_size = param_data.shape[-1] inter_size = param_data.shape[-1]
elif key == 'w_qkv':
split_dim = -2
elif key in ['w2', 'wo']: elif key in ['w2', 'wo']:
if ext in ['scales', 'zeros', 'bias']: if ext in ['scales', 'zeros', 'bias']:
copy = True copy = True
...@@ -167,6 +170,7 @@ def export(model_name: str, ...@@ -167,6 +170,7 @@ def export(model_name: str,
cfg = dict(llama=dict( cfg = dict(llama=dict(
model_name=model_name, model_name=model_name,
head_num=head_num, head_num=head_num,
kv_head_num=kv_head_num,
size_per_head=size_per_head, size_per_head=size_per_head,
vocab_size=vocab_size, vocab_size=vocab_size,
num_layer=num_layer, num_layer=num_layer,
...@@ -184,7 +188,7 @@ def export(model_name: str, ...@@ -184,7 +188,7 @@ def export(model_name: str,
step_length=1, step_length=1,
cache_max_entry_count=48, cache_max_entry_count=48,
cache_chunk_size=1, cache_chunk_size=1,
use_context_fmha=1, use_context_fmha=int(kv_head_num == head_num),
quant_policy=0, quant_policy=0,
tensor_para_size=tp)) tensor_para_size=tp))
...@@ -198,6 +202,15 @@ def export(model_name: str, ...@@ -198,6 +202,15 @@ def export(model_name: str,
return True return True
def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
dim: int):
def reshape(x):
return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
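A quick shape check of the new `merge_qkv` helper (a minimal sketch; the Llama-2 70B-like sizes below are assumed for illustration). Each tensor-parallel rank ends up with a contiguous `[Q_rank | K_rank | V_rank]` slice, which is why the exporter can later split `w_qkv` along the rank dimension (`split_dim = -2` above):

```python
import torch

def merge_qkv(q, k, v, tp, dim):
    def reshape(x):
        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)

# assumed sizes: hidden 8192, 64 query heads, 8 KV heads, head_dim 128, tp 8
hidden, head_num, kv_head_num, head_dim, tp = 8192, 64, 8, 128, 8
q = torch.empty(hidden, head_num * head_dim)     # wq
k = torch.empty(hidden, kv_head_num * head_dim)  # wk (smaller under GQA)
v = torch.empty(hidden, kv_head_num * head_dim)  # wv

qkv = merge_qkv(q, k, v, tp, dim=2)              # dim=1 is used for biases
assert qkv.shape == (hidden, tp, (head_num + 2 * kv_head_num) * head_dim // tp)
```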
def deploy_llama(model_name: str, model_path: str, tokenizer_path: str, def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int): triton_models_path: str, tp: int):
"""Deploy a model with huggingface transformers' format. """Deploy a model with huggingface transformers' format.
...@@ -223,6 +236,8 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str, ...@@ -223,6 +236,8 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f) model_arg = json.load(f)
num_layer = model_arg['n_layers'] num_layer = model_arg['n_layers']
norm_eps = model_arg['norm_eps'] norm_eps = model_arg['norm_eps']
head_num = model_arg.get('n_heads', 32)
kv_head_num = model_arg.get('n_kv_heads', head_num)
except Exception as e: except Exception as e:
print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}') print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}')
return False return False
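For reference, a sketch of how the two new fields are read from a Meta-format checkpoint; the `params.json` contents shown in the comment are hypothetical and only the keys used above matter:

```python
import json

# hypothetical Llama-2 70B params.json:
# {"dim": 8192, "n_layers": 80, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05}
with open('params.json') as f:                       # path is illustrative
    model_arg = json.load(f)

head_num = model_arg.get('n_heads', 32)
kv_head_num = model_arg.get('n_kv_heads', head_num)  # absent => plain MHA
```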
...@@ -268,7 +283,6 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str, ...@@ -268,7 +283,6 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
else: # bias else: # bias
param = get_param(param_name, [size]) param = get_param(param_name, [size])
param.data = param_data param.data = param_data
elif i == 0: elif i == 0:
param = get_param(param_name, param_data.size()) param = get_param(param_name, param_data.size())
param.data = param_data param.data = param_data
...@@ -291,14 +305,14 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str, ...@@ -291,14 +305,14 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
qkv = tuple(map(model_params.pop, _qkv)) qkv = tuple(map(model_params.pop, _qkv))
except KeyError: except KeyError:
break break
# concat by output_dims # concat by heads
qkv = torch.stack(qkv, dim=qkv[0].dim() - 1) qkv = merge_qkv(*qkv, tp, dim=2 if t == 'weight' else 1)
print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape) print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv
assert i == 0 or num_layer == i, f'miss matched layers: {num_layer} vs {i}' assert i == 0 or num_layer == i, f'miss matched layers: {num_layer} vs {i}'
return export(model_name, num_layer, norm_eps, model_params, return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp) tokenizer_path, triton_models_path, tp)
...@@ -349,6 +363,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, ...@@ -349,6 +363,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f) model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers'] num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps'] norm_eps = model_arg['rms_norm_eps']
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
except Exception as e: except Exception as e:
print(f'get "num_hidden_layers" and "rms_norm_eps" from ' print(f'get "num_hidden_layers" and "rms_norm_eps" from '
f'{params_path} failed: {e}') f'{params_path} failed: {e}')
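The HF path mirrors this: `num_key_value_heads` only appears in GQA configs (e.g. 8 for Llama-2 70B), so it falls back to `num_attention_heads`. Equivalent to the `if`/`else` above (a sketch):

```python
import json

with open('config.json') as f:                        # path is illustrative
    model_arg = json.load(f)

kv_head_num = model_arg.get('num_key_value_heads',
                            model_arg['num_attention_heads'])
```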
...@@ -416,11 +434,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, ...@@ -416,11 +434,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
q = permute(q) q = permute(q)
k = permute(k) k = permute(k)
if suffix == _qweight: # weight, qweight if suffix == _qweight: # weight, qweight
# insert a dimension for splitting heads later qkv = merge_qkv(q, k, v, tp, dim=2)
qkv = torch.stack((q, k, v), dim=1) print(suffix, qkv.shape)
else: # scales, zeros, bias else: # scales, zeros, bias
qkv = torch.stack((q.squeeze(), k.squeeze(), v.squeeze()), qkv = merge_qkv(q, k, v, tp, dim=1)
dim=0).squeeze(dim=-1)
print(suffix, qkv.shape) print(suffix, qkv.shape)
for k, v in [('w_qkv', qkv), ('wo', o)]: for k, v in [('w_qkv', qkv), ('wo', o)]:
model_params[f'layers.{i}.attention.{k}.{suffix}'] = v model_params[f'layers.{i}.attention.{k}.{suffix}'] = v
...@@ -456,7 +473,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, ...@@ -456,7 +473,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
for ft, hf in other: for ft, hf in other:
model_params[ft] = get_tensor(hf) model_params[ft] = get_tensor(hf)
return export(model_name, num_layer, norm_eps, model_params, return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp) tokenizer_path, triton_models_path, tp)
......
...@@ -132,6 +132,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base<T> { ...@@ -132,6 +132,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base<T> {
T** v_cache_per_sample = nullptr; T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0; size_t kv_cache_per_sample_offset = 0;
bool k_cache_interleaved = true; bool k_cache_interleaved = true;
int num_kv_heads = 0;
}; };
template<class T> template<class T>
......
...@@ -677,7 +677,7 @@ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) ...@@ -677,7 +677,7 @@ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
" {%7, %7, %7, %7}; \n" " {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
return c; return c;
} }
...@@ -1349,16 +1349,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1349,16 +1349,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (params.finished != nullptr && params.finished[bi] == true) { if (params.finished != nullptr && params.finished[bi] == true) {
return; return;
} }
// The beam idx
const int beami = bi % params.beam_width;
// The "beam-aware" batch idx
const int bbi = bi / params.beam_width;
// The head. // The head.
const int hi = blockIdx.x; const int hi = blockIdx.x;
// Combine the batch and the head indices. // Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi; const int bhi = bi * params.num_heads + hi;
// Combine the "beam-aware" batch idx and the head indices.
const int bbhi = bbi * params.beam_width * params.num_heads + hi; const int head_n_rep = params.num_heads / params.num_kv_heads;
const int kvhi = hi / head_n_rep; // heads in the same group collapse to the same kv head
const bool group_leader = hi % head_n_rep == 0; // only group leader writes to kv cache
// The thread in the block. // The thread in the block.
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
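The two new indices replace the removed beam-search bookkeeping: each query head maps onto a shared KV head, and only the first head of each group writes that KV-cache entry. A minimal sketch of the mapping, with 64 query heads and 8 KV heads assumed:

```python
num_heads, num_kv_heads = 64, 8            # assumed Llama-2 70B values
head_n_rep = num_heads // num_kv_heads     # 8 query heads share one KV head

for hi in range(num_heads):                # hi == blockIdx.x in the kernel
    kvhi = hi // head_n_rep                # KV head this query head reads
    group_leader = hi % head_n_rep == 0    # only this head stores K/V
    assert 0 <= kvhi < num_kv_heads
```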
...@@ -1369,8 +1371,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1369,8 +1371,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
float qk = 0.0F; float qk = 0.0F;
int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
const size_t bi_seq_len_offset = bi * params.memory_max_len; const size_t bi_seq_len_offset = bi * params.memory_max_len;
const int tlength = params.length_per_sample[bi] + params.max_prefix_prompt_length; const int tlength = params.length_per_sample[bi] + params.max_prefix_prompt_length;
...@@ -1380,10 +1380,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1380,10 +1380,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP; const bool is_masked = tidx >= QK_VECS_PER_WARP;
const int q_base_offset = bi * params.stride + hi * Dh;
const int k_base_offset = bi * params.stride + kvhi * Dh;
// The offset in the Q and K buffer also accounts for the batch. // The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; const int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
const int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
// The offset in the bias buffer. // The offset in the bias buffer.
int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; const int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const int k_bias_offset = kvhi * Dh + tidx * QK_VEC_SIZE;
// past kv quant param // past kv quant param
const float k_scale = params.attention_k_scale; const float k_scale = params.attention_k_scale;
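Splitting `qkv_base_offset` into separate Q and K offsets is what lets K (and later V) be indexed by `kvhi`. A sketch of the offset arithmetic with sizes assumed; `params.q`/`params.k` already point at the start of the Q and K blocks of the fused buffer, so only the per-token stride and head offset are added here:

```python
H, kvH, Dh = 64, 8, 128              # assumed head counts and head size
stride = (H + 2 * kvH) * Dh          # params.stride: one token's fused QKV row

def q_base_offset(bi, hi):           # offset into params.q
    return bi * stride + hi * Dh

def k_base_offset(bi, kvhi):         # offset into params.k (also reused for V loads)
    return bi * stride + kvhi * Dh
```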
...@@ -1393,31 +1399,30 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1393,31 +1399,30 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
Qk_vec_k q; Qk_vec_k q;
zero(q); zero(q);
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset])); q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[q_offset]));
} }
Qk_vec_k k; Qk_vec_k k;
zero(k); zero(k);
{ {
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) : vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[k_offset])) :
k; k;
} }
// Trigger the loads from the Q and K bias buffers. // Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias; Qk_vec_k q_bias;
zero(q_bias); zero(q_bias);
q_bias = q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
(!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[q_bias_offset])) :
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[qk_bias_offset])) : q_bias;
q_bias;
Qk_vec_k k_bias; Qk_vec_k k_bias;
zero(k_bias); zero(k_bias);
if (handle_kv) { if (handle_kv) {
k_bias = k_bias =
!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[qk_bias_offset])) : vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset])) :
k_bias; k_bias;
} }
...@@ -1454,7 +1459,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1454,7 +1459,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The position of the thread in that 16B chunk. // The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
if (handle_kv) { if (handle_kv && group_leader) {
// Trigger the stores to global memory. // Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
if (!params.k_cache_per_sample) { if (!params.k_cache_per_sample) {
...@@ -1476,12 +1481,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1476,12 +1481,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else { else {
int offset; int offset;
if (params.k_cache_interleaved) { if (params.k_cache_interleaved) {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci; + co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci;
} }
else { else {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ co * QK_ELTS_IN_16B + ci; + tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;
} }
if (not QUANT_POLICY) { if (not QUANT_POLICY) {
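With per-sample caches, a layer's cache block now holds only the KV heads, so the head term uses `kvhi`. A sketch of the two addressing modes above, assuming fp16 so `QK_ELTS_IN_16B == 8`:

```python
def k_cache_offset(base, kvhi, t, co, ci, memory_max_len, Dh,
                   interleaved, QK_ELTS_IN_16B=8):
    head = base + kvhi * memory_max_len * Dh     # per KV head, not per query head
    if interleaved:
        return head + co * memory_max_len * QK_ELTS_IN_16B + t * QK_ELTS_IN_16B + ci
    return head + t * Dh + co * QK_ELTS_IN_16B + ci
```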
...@@ -1577,7 +1582,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1577,7 +1582,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (not QUANT_POLICY) { if (not QUANT_POLICY) {
k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + ki) : + kvhi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki]; &params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer // Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki]; // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
...@@ -1586,7 +1591,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1586,7 +1591,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// convert k_cache_per_sample to int8 // convert k_cache_per_sample to int8
if (params.k_cache_per_sample) { if (params.k_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]); int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki; k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
} }
else { else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache); int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
...@@ -1765,7 +1770,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1765,7 +1770,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (not QUANT_POLICY) { if (not QUANT_POLICY) {
v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + vi) : + kvhi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi]; &params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer // Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi]; // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
...@@ -1774,7 +1779,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1774,7 +1779,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else if (QUANT_POLICY == 4) { else if (QUANT_POLICY == 4) {
if (params.v_cache_per_sample) { if (params.v_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]); int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi; v_cache_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
} }
else { else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache); int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
...@@ -1787,22 +1792,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1787,22 +1792,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The number of values processed per iteration of the loop. // The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
zero(v_bias);
// if( vo == params.timestep % V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
if (handle_kv) {
if (vo == tlength % V_PER_ITER) {
// Trigger the loads from the V bias buffer.
if (params.v_bias != nullptr) {
v_bias = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&params.v_bias[hi * Dh + vi]));
}
}
}
}
// From previous, before values, step // From previous, before values, step
// Also make sure the logits are in shared memory. // Also make sure the logits are in shared memory.
__syncthreads(); __syncthreads();
...@@ -1924,14 +1913,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1924,14 +1913,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
V_vec_k v; V_vec_k v;
// Trigger the loads from the V buffer. // Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi; const auto v_offset = k_base_offset + vi;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
// Trigger the loads from the V bias buffer. // Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]); if (params.v_bias != nullptr) {
V_vec_k v_bias = *reinterpret_cast<const V_vec_k*>(&params.v_bias[kvhi * Dh + vi]);
v = add(v, v_bias);
}
// Compute the V values with bias. // Store the V values to cache
if (handle_kv) { if (handle_kv && group_leader) {
v = add(v, v_bias);
// Store the values with bias back to global memory in the cache for V. // Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v; //*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "src/turbomind/kernels/unfused_attention_kernels.h" #include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/utils/cuda_type_utils.cuh" #include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
namespace turbomind { namespace turbomind {
...@@ -1336,6 +1337,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1336,6 +1337,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
const int batch_size, const int batch_size,
const int seq_len, const int seq_len,
const int head_num, const int head_num,
const int kv_head_num,
const int size_per_head, const int size_per_head,
const int rotary_embedding_dim, const int rotary_embedding_dim,
const bool neox_rotary_style) const bool neox_rotary_style)
...@@ -1396,7 +1398,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1396,7 +1398,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0; const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size; const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
const int n = head_num * size_per_head; // const int n = head_num * size_per_head;
// the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len) // the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len)
// and Q [0..seq_len) // and Q [0..seq_len)
...@@ -1405,33 +1407,48 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1405,33 +1407,48 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
// NOTE: q has seq len excluding prefix prompt // NOTE: q has seq len excluding prefix prompt
// src QKV: [batch, time, 3, head, hidden] // src QKV: [batch, time, 3, head, hidden]
const int src_q_idx = token_idx * 3 * n + hidden_idx; // const int src_q_idx = token_idx * 3 * n + hidden_idx;
const int src_k_idx = token_idx * 3 * n + hidden_idx + n; // const int src_k_idx = token_idx * 3 * n + hidden_idx + n;
const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n; // const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n;
const int q_kv_head_num = head_num + 2 * kv_head_num;
const int k_offset = head_num * size_per_head;
const int v_offset = k_offset + kv_head_num * size_per_head;
// src QKV: [batch, time, q_kv_head_num, hidden]
const int src_q_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx;
const int src_k_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + k_offset;
const int src_v_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + v_offset;
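The commented-out indices assumed a `[batch, time, 3, head, hidden]` layout; with GQA the fused tensor is `[batch, time, head_num + 2*kv_head_num, size_per_head]` and the K/V blocks exist only for `head_idx < kv_head_num`. A sketch of the new source indexing, sizes assumed:

```python
H, kvH, Dh = 64, 8, 128
q_kv_head_num = H + 2 * kvH
k_offset = H * Dh
v_offset = k_offset + kvH * Dh

def src_indices(token_idx, head_idx, elem):   # elem == tidx * vec_size
    hidden_idx = head_idx * Dh + elem
    base = token_idx * q_kv_head_num * Dh
    # the K/V indices are only meaningful when head_idx < kvH
    return base + hidden_idx, base + hidden_idx + k_offset, base + hidden_idx + v_offset
```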
Vec_t q, k, v; Vec_t q, k, v;
Vec_t q_bias, k_bias, v_bias; Vec_t q_bias, k_bias, v_bias;
// load Q and apply bias
if (!is_masked) { if (!is_masked) {
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]); q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (qkv_bias) { if (qkv_bias) {
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]); q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n]); q = mmha::add(q, q_bias);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + 2 * n]);
} }
} }
if (qkv_bias) { // load KV and apply bias
q = mmha::add(q, q_bias); if (!is_masked && head_idx < kv_head_num) {
k = mmha::add(k, k_bias); k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = mmha::add(v, v_bias); v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (qkv_bias) {
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
}
} }
const int t_offset = history_length ? history_length[batch_idx] : 0; const int t_offset = history_length ? history_length[batch_idx] : 0;
if (!neox_rotary_style) { if (!neox_rotary_style) {
// TODO: unused computation on k if GQA is used
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx + t_offset); mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx + t_offset);
} }
else { else {
...@@ -1472,23 +1489,28 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1472,23 +1489,28 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx); k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
} }
} }
if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q; *reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k; if (head_idx < kv_head_num) {
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v; *reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
}
} }
const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len
+ seq_idx * size_per_head + tidx * vec_size; + seq_idx * size_per_head + tidx * vec_size;
const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * head_num const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num
+ head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head + head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head
+ tidx * vec_size; + tidx * vec_size;
if (!is_masked) { if (!is_masked) {
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q; *reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k; if (head_idx < kv_head_num) {
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v; *reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k;
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v;
}
} }
} }
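After the transpose, `q_buf_2_` still holds every query head while `k_buf_2_`/`v_buf_2_` hold only the KV heads, which is exactly what `dest_q_idx`/`dest_kv_idx` encode. A sketch with per-rank sizes assumed:

```python
B, s, total_s, H, kvH, D = 8, 512, 512, 8, 1, 128   # per-TP-rank head counts assumed

q_buf_shape  = (B, H,   s,       D)                 # written via dest_q_idx
kv_buf_shape = (B, kvH, total_s, D)                 # written via dest_kv_idx (k_buf and v_buf)

def dest_q_idx(b, h, s_idx, elem):
    return ((b * H + h) * s + s_idx) * D + elem

def dest_kv_idx(b, kvh, t_idx, elem):
    return ((b * kvH + kvh) * total_s + t_idx) * D + elem
```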
...@@ -1504,6 +1526,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* ...@@ -1504,6 +1526,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
batch_size, \ batch_size, \
seq_len, \ seq_len, \
head_num, \ head_num, \
kv_head_num, \
size_per_head, \ size_per_head, \
rotary_embedding_dim, \ rotary_embedding_dim, \
neox_rotary_style); neox_rotary_style);
...@@ -1521,6 +1544,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, ...@@ -1521,6 +1544,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len, const int seq_len,
const int token_num, const int token_num,
const int head_num, const int head_num,
const int kv_head_num,
const int size_per_head, const int size_per_head,
const int rotary_embedding_dim, const int rotary_embedding_dim,
const int neox_rotary_style, const int neox_rotary_style,
...@@ -1528,42 +1552,18 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, ...@@ -1528,42 +1552,18 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int int8_mode, const int int8_mode,
cudaStream_t stream) cudaStream_t stream)
{ {
// [bs, seq_len, 3, head, Dh] FT_CHECK(rotary_embedding_dim);
if (rotary_embedding_dim == 0 && param.max_prefix_prompt_length == 0) { FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec)
const int m = token_num; // To implement rotary embeddings, each thread processes two QKV elems:
const int n = head_num * size_per_head; dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 block(384); dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num);
dim3 grid((int)(ceil(1.0 * m * n / 384))); size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0;
add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(q_buf, // NOTE: add offset for rotary embedding
k_buf, if (param.max_prefix_prompt_length == 0) {
v_buf, FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
QKV,
qkv_bias,
padding_offset,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
scale,
int8_mode);
} }
else { else {
FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec) FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true);
// To implement rotary embeddings, each thread processes two QKV elems:
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num);
size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0;
// NOTE: add offset for rotary embedding
// add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(
// q_buf, k_buf, v_buf, param, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head,
// rotary_embedding_dim);
if (param.max_prefix_prompt_length == 0) {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
}
else {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true);
}
} }
} }
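The launcher no longer has a separate no-rotary fast path; `FT_CHECK(rotary_embedding_dim)` makes that assumption explicit, and every call goes through the head-parallel grid. A sketch of the launch geometry, assuming `Vec_t<half>::size == 2`:

```python
size_per_head, vec_size = 128, 2          # vec_size: Vec_t<half>::size, assumed 2
token_num, batch_size, max_prefix_prompt_length = 4096, 8, 0
head_num = 8                              # grid.y covers query heads only; KV work
                                          # inside the kernel is gated by head_idx < kv_head_num

block_x = (size_per_head // vec_size + 31) // 32 * 32   # 64 threads per block
grid = (token_num + batch_size * max_prefix_prompt_length, head_num)
print(block_x, grid)                      # 64 (4096, 8)
```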
...@@ -1580,6 +1580,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, ...@@ -1580,6 +1580,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len, \ const int seq_len, \
const int token_num, \ const int token_num, \
const int head_num, \ const int head_num, \
const int kv_head_num, \
const int size_per_head, \ const int size_per_head, \
const int rotary_embedding_dim, \ const int rotary_embedding_dim, \
const int neox_rotary_style, \ const int neox_rotary_style, \
......
...@@ -112,6 +112,7 @@ struct PrefixPromptBatchWeightsParam { ...@@ -112,6 +112,7 @@ struct PrefixPromptBatchWeightsParam {
// l * 2 * hidden_units_ / tensor_para_.world_size_ // l * 2 * hidden_units_ / tensor_para_.world_size_
const size_t prefix_prompt_layer_offset_per_seq = 0; const size_t prefix_prompt_layer_offset_per_seq = 0;
}; };
template<typename T> template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf, void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf, T* k_buf,
...@@ -125,83 +126,13 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, ...@@ -125,83 +126,13 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len, const int seq_len,
const int token_num, const int token_num,
const int head_num, const int head_num,
const int kv_head_num,
const int size_per_head, const int size_per_head,
const int rotary_embedding_dim, const int rotary_embedding_dim,
const int neox_rotary_style, const int neox_rotary_style,
const float* scale, const float* scale,
const int int8_mode, const int int8_mode,
cudaStream_t stream); cudaStream_t stream);
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam<T> param,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf,
k_buf,
v_buf,
param,
QKV,
qkv_bias,
padding_offset,
nullptr,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
rotary_embedding_dim,
neox_rotary_style,
scale,
int8_mode,
stream);
}
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf,
k_buf,
v_buf,
PrefixPromptBatchWeightsParam<T>{},
QKV,
qkv_bias,
padding_offset,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
0,
false,
(float*)nullptr,
0,
stream);
}
template<typename T> template<typename T>
void invokeTranspose4d(T* dst, void invokeTranspose4d(T* dst,
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
namespace turbomind { namespace turbomind {
...@@ -38,13 +39,17 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size, ...@@ -38,13 +39,17 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
// no padding // no padding
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * 3 * local_hidden_units_, true); qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);
// padding is rebuilt for q/k/v_buf_2_ // padding is rebuilt for q/k/v_buf_2_
q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * 3 * batch_size * max_q_len * local_hidden_units_, true); // [qH + 2kvH, B, S, D]
k_buf_2_ = q_buf_2_ + batch_size * max_q_len * local_hidden_units_; q_buf_2_ = (T*)allocator_->reMalloc(
v_buf_2_ = k_buf_2_ + batch_size * max_q_len * local_hidden_units_; q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;
if (use_fmha_) { if (use_fmha_) {
FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_); FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
...@@ -53,6 +58,7 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size, ...@@ -53,6 +58,7 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
} }
} }
else { else {
// kv heads are repeated for unfused attention
k_cache_buf_ = (T*)allocator_->reMalloc( k_cache_buf_ = (T*)allocator_->reMalloc(
k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true); k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_; v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;
...@@ -61,12 +67,12 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size, ...@@ -61,12 +67,12 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
(T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true); (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);
// qkv_buf_2_ has padding // qkv_buf_2_ has padding
qkv_buf_2_ = qkv_buf_2_ = (T*)allocator_->reMalloc(
(T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_hidden_units_, true); qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
} }
// qkv_buf_3_ padding is removed // qkv_buf_3_ padding is removed
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_hidden_units_, true); qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);
is_allocate_buffer_ = true; is_allocate_buffer_ = true;
} }
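Buffer sizes are now derived from explicit head counts rather than `local_hidden_units_`, because the fused QKV activation shrinks under GQA. A sketch of the partitioning with per-rank head counts assumed (Llama-2 70B at tp=8):

```python
B, S, D = 8, 2048, 128         # batch_size, max_q_len, size_per_head_
H, kvH = 8, 1                  # local_head_num_, local_kv_head_num_ per rank
num_token = B * S              # qkv_buf_ has no padding

local_q_kv_head_num = H + 2 * kvH
qkv_buf_elems = num_token * local_q_kv_head_num * D    # fused GEMM output
q_buf_2_elems = local_q_kv_head_num * B * S * D        # layout [qH + 2kvH, B, S, D]
k_buf_2_off   = H * B * S * D                          # k_buf_2_ = q_buf_2_ + this
v_buf_2_off   = k_buf_2_off + kvH * B * S * D          # v_buf_2_ = q_buf_2_ + this
```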
...@@ -152,7 +158,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -152,7 +158,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
////////////////////////////////////////////// //////////////////////////////////////////////
/// transpose qkv & apply rotary embedding & rebuild padding /// transpose qkv & apply rotary embedding & rebuild padding
/// qkv [B, s, 3, H, D] -> (q [B, H, s, D], k [B, H, s, D], v [B, H, s, D]) /// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D])
invokeAddFusedQKVBiasTranspose(q_buf_2_, invokeAddFusedQKVBiasTranspose(q_buf_2_,
k_buf_2_, k_buf_2_,
v_buf_2_, v_buf_2_,
...@@ -165,6 +171,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -165,6 +171,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
max_q_len, // seq_len max_q_len, // seq_len
num_token, // batch_size * seq_len num_token, // batch_size * seq_len
local_head_num_, local_head_num_,
local_kv_head_num_,
size_per_head_, size_per_head_,
rotary_embedding_dim_, rotary_embedding_dim_,
neox_rotary_style_, neox_rotary_style_,
...@@ -173,16 +180,16 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -173,16 +180,16 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
stream_); stream_);
sync_check_cuda_error(); sync_check_cuda_error();
const size_t layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_; const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache"); auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache"); auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
/// insert the k/v computed from inputs into k/v cache /// insert the k/v computed from inputs into k/v cache
/// transpose kv -> kv cache /// transpose kv -> kv cache
// put k/v_buf from shape [B, H, s, D] to // put k/v_buf from shape [B, kvH, s, D] to
// k_buf_2 [B, H, s, D] -> key_cache [B, H, S[t:t+s], D/x, x] // k_buf_2 [B, kvH, s, D] -> key_cache [B, kvH, S[t:t+s], D/x, x]
// v_buf_2 [B, H, s, D] -> val_cache [B, H, S[t:t+s], D/x, x] // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x]
invokeExtendKVCache(k_cache_ptrs, invokeExtendKVCache(k_cache_ptrs,
v_cache_ptrs, v_cache_ptrs,
layer_offset, layer_offset,
...@@ -194,13 +201,14 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -194,13 +201,14 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
history_length, history_length,
max_seq_len, max_seq_len,
size_per_head_, size_per_head_,
local_head_num_, local_kv_head_num_,
stream_, stream_,
quant_policy_, quant_policy_,
weights->past_kv_scale.data()); weights->past_kv_scale.data());
sync_check_cuda_error(); sync_check_cuda_error();
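Since the cache stores only `local_kv_head_num_` heads, the per-layer stride (and the KV-cache footprint as a whole) shrinks by `head_n_rep_`. A quick comparison with assumed 70B, tp=8 numbers:

```python
local_head_num, local_kv_head_num = 8, 1        # per TP rank, assumed
max_seq_len, size_per_head = 2048, 128

per_layer_mha = local_head_num    * max_seq_len * size_per_head
per_layer_gqa = local_kv_head_num * max_seq_len * size_per_head
print(per_layer_mha // per_layer_gqa)           # 8x smaller K (and V) cache per layer
```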
if (use_fmha_) { if (use_fmha_) {
FT_CHECK(local_head_num_ == local_kv_head_num_);
fusedMultiHeadAttention(k_cache_ptrs, fusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs, v_cache_ptrs,
layer_offset, layer_offset,
...@@ -311,8 +319,8 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c ...@@ -311,8 +319,8 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c
int quant, int quant,
const float* kv_scale) const float* kv_scale)
{ {
// key_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
// val_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
invokeTransposeKVCache(k_cache_buf_, invokeTransposeKVCache(k_cache_buf_,
v_cache_buf_, v_cache_buf_,
(const T**)key_cache_ptrs, (const T**)key_cache_ptrs,
...@@ -324,6 +332,7 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c ...@@ -324,6 +332,7 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c
max_seq_len, max_seq_len,
size_per_head_, size_per_head_,
local_head_num_, local_head_num_,
head_n_rep_,
stream_, stream_,
quant, quant,
kv_scale); kv_scale);
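The unfused fallback still runs one batched GEMM per query head, so `invokeTransposeKVCache` now receives `head_n_rep_` and materializes each cached KV head `head_n_rep_` times into `k_cache_buf_`/`v_cache_buf_`, as the "kv heads are repeated" comment above notes. Conceptually (a torch sketch, not the actual kernel):

```python
import torch

B, kvH, rep, t, D = 2, 1, 8, 32, 128             # rep == head_n_rep_
k_cache = torch.randn(B, kvH, t, D)              # what the cache stores
k_full = k_cache.repeat_interleave(rep, dim=1)   # what the GEMMs consume
assert k_full.shape == (B, kvH * rep, t, D)      # i.e. [B, local_head_num, t, D]
```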
......
...@@ -35,6 +35,7 @@ public: ...@@ -35,6 +35,7 @@ public:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
LlamaContextAttentionLayer(size_t head_num, LlamaContextAttentionLayer(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t rotary_embedding_dim, size_t rotary_embedding_dim,
bool neox_rotary_style, bool neox_rotary_style,
...@@ -49,7 +50,8 @@ public: ...@@ -49,7 +50,8 @@ public:
size_per_head_(size_per_head), size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head), hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_), local_head_num_(head_num / tensor_para.world_size_),
local_hidden_units_(hidden_units_ / tensor_para.world_size_), local_kv_head_num_(kv_head_num / tensor_para.world_size_),
head_n_rep_(head_num / kv_head_num),
rotary_embedding_dim_(rotary_embedding_dim), rotary_embedding_dim_(rotary_embedding_dim),
neox_rotary_style_(neox_rotary_style), neox_rotary_style_(neox_rotary_style),
tensor_para_(tensor_para), tensor_para_(tensor_para),
...@@ -61,6 +63,7 @@ public: ...@@ -61,6 +63,7 @@ public:
use_fmha_(use_fmha), use_fmha_(use_fmha),
quant_policy_(quant_policy) quant_policy_(quant_policy)
{ {
FT_CHECK(head_num % kv_head_num == 0);
} }
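The new `FT_CHECK` guarantees query heads divide evenly into KV groups; the derived members then follow directly. With Llama-2 70B at tp=8 (values assumed):

```python
head_num, kv_head_num, tp = 64, 8, 8
assert head_num % kv_head_num == 0            # FT_CHECK in the constructor

local_head_num    = head_num // tp            # 8 query heads per rank
local_kv_head_num = kv_head_num // tp         # 1 KV head per rank
head_n_rep        = head_num // kv_head_num   # 8 query heads share each KV head
```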
void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight<T>* weights); void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight<T>* weights);
...@@ -93,8 +96,9 @@ private: ...@@ -93,8 +96,9 @@ private:
const size_t head_num_; const size_t head_num_;
const size_t size_per_head_; const size_t size_per_head_;
const size_t hidden_units_; const size_t hidden_units_;
const size_t local_kv_head_num_;
const size_t local_head_num_; const size_t local_head_num_;
const size_t local_hidden_units_; const size_t head_n_rep_;
const size_t rotary_embedding_dim_; const size_t rotary_embedding_dim_;
const bool is_free_buffer_after_forward_; const bool is_free_buffer_after_forward_;
......
...@@ -62,11 +62,12 @@ void LlamaContextDecoder<T>::freeBuffer() ...@@ -62,11 +62,12 @@ void LlamaContextDecoder<T>::freeBuffer()
} }
template<typename T> template<typename T>
void LlamaContextDecoder<T>::initialize(bool use_fmha, int quant_policy) void LlamaContextDecoder<T>::initialize(size_t kv_head_num, bool use_fmha, int quant_policy)
{ {
h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_, context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
kv_head_num,
size_per_head_, size_per_head_,
rotary_embedding_dim_, rotary_embedding_dim_,
false, // neox_rotary_style false, // neox_rotary_style
...@@ -124,6 +125,7 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session& ...@@ -124,6 +125,7 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
template<typename T> template<typename T>
LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num, LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
...@@ -147,7 +149,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num, ...@@ -147,7 +149,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
tensor_para_(tensor_para), tensor_para_(tensor_para),
data_type_(getTensorType<T>()) data_type_(getTensorType<T>())
{ {
initialize(use_fmha, quant_policy); initialize(kv_head_num, use_fmha, quant_policy);
} }
template<typename T> template<typename T>
......
...@@ -43,7 +43,7 @@ protected: ...@@ -43,7 +43,7 @@ protected:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
void freeBuffer() override; void freeBuffer() override;
void initialize(bool use_fmha, int quant_policy); void initialize(size_t kv_head_num, bool use_fmha, int quant_policy);
size_t head_num_; size_t head_num_;
size_t size_per_head_; size_t size_per_head_;
...@@ -88,6 +88,7 @@ protected: ...@@ -88,6 +88,7 @@ protected:
public: public:
LlamaContextDecoder(size_t head_num, LlamaContextDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
......
...@@ -28,6 +28,7 @@ namespace turbomind { ...@@ -28,6 +28,7 @@ namespace turbomind {
template<typename T> template<typename T>
LlamaDecoder<T>::LlamaDecoder(size_t head_num, LlamaDecoder<T>::LlamaDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
...@@ -51,7 +52,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num, ...@@ -51,7 +52,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num,
data_type_(getTensorType<T>()) data_type_(getTensorType<T>())
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
initialize(quant_policy); initialize(kv_head_num, quant_policy);
} }
template<typename T> template<typename T>
...@@ -63,11 +64,12 @@ LlamaDecoder<T>::~LlamaDecoder() ...@@ -63,11 +64,12 @@ LlamaDecoder<T>::~LlamaDecoder()
} }
template<typename T> template<typename T>
void LlamaDecoder<T>::initialize(int quant_policy) void LlamaDecoder<T>::initialize(size_t kv_head_num, int quant_policy)
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_, self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_,
kv_head_num,
size_per_head_, size_per_head_,
rotary_embedding_dim_, rotary_embedding_dim_,
false, // neox_rotary_style false, // neox_rotary_style
......
...@@ -35,7 +35,7 @@ protected: ...@@ -35,7 +35,7 @@ protected:
void allocateBuffer() override; // deprecated void allocateBuffer() override; // deprecated
void allocateBuffer(size_t batch_size); void allocateBuffer(size_t batch_size);
void freeBuffer() override; void freeBuffer() override;
void initialize(int quant_policy); void initialize(size_t kv_head_num, int quant_policy);
size_t head_num_; size_t head_num_;
size_t size_per_head_; size_t size_per_head_;
...@@ -70,6 +70,7 @@ protected: ...@@ -70,6 +70,7 @@ protected:
public: public:
LlamaDecoder(size_t head_num, LlamaDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
...@@ -80,9 +81,9 @@ public: ...@@ -80,9 +81,9 @@ public:
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
int quant_policy), int quant_policy);
~LlamaDecoder() override; ~LlamaDecoder() override;
virtual void forward(std::unordered_map<std::string, Tensor>* output_tensors, virtual void forward(std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors, const std::unordered_map<std::string, Tensor>* input_tensors,
......
...@@ -25,13 +25,18 @@ ...@@ -25,13 +25,18 @@
namespace turbomind { namespace turbomind {
template<typename T> template<typename T>
LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t hidden_units, LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size, size_t inter_size,
WeightType weight_type, WeightType weight_type,
bool attn_bias, bool attn_bias,
size_t tensor_para_size, size_t tensor_para_size,
size_t tensor_para_rank): size_t tensor_para_rank):
hidden_units_(hidden_units), head_num_(head_num),
kv_head_num_(kv_head_num),
size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head),
inter_size_(inter_size), inter_size_(inter_size),
weight_type_(weight_type), weight_type_(weight_type),
attn_bias_(attn_bias), attn_bias_(attn_bias),
...@@ -39,7 +44,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t hidden_units, ...@@ -39,7 +44,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t hidden_units,
tensor_para_rank_(tensor_para_rank) tensor_para_rank_(tensor_para_rank)
{ {
self_attn_weights.qkv.input_dims = hidden_units_; self_attn_weights.qkv.input_dims = hidden_units_;
self_attn_weights.qkv.output_dims = 3 * hidden_units_ / tensor_para_size_; self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
self_attn_weights.qkv.type = weight_type; self_attn_weights.qkv.type = weight_type;
self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
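The fused QKV projection per TP rank shrinks accordingly; with 70B-like sizes assumed, the per-rank output width drops from `3 * hidden_units / tp` to `(head_num + 2 * kv_head_num) * size_per_head / tp`:

```python
head_num, kv_head_num, size_per_head, tp = 64, 8, 128, 8
hidden_units = head_num * size_per_head                                # 8192

old_output_dims = 3 * hidden_units // tp                               # 3072
new_output_dims = (head_num + 2 * kv_head_num) * size_per_head // tp   # 1280
```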
......
...@@ -28,7 +28,9 @@ template<typename T> ...@@ -28,7 +28,9 @@ template<typename T>
struct LlamaDecoderLayerWeight { struct LlamaDecoderLayerWeight {
public: public:
LlamaDecoderLayerWeight() = delete; LlamaDecoderLayerWeight() = delete;
LlamaDecoderLayerWeight(size_t hidden_units, LlamaDecoderLayerWeight(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size, size_t inter_size,
WeightType weight_type, WeightType weight_type,
bool attn_bias, bool attn_bias,
...@@ -46,6 +48,9 @@ public: ...@@ -46,6 +48,9 @@ public:
LlamaFfnWeight<T> ffn_weights{}; LlamaFfnWeight<T> ffn_weights{};
private: private:
size_t head_num_;
size_t kv_head_num_;
size_t size_per_head_;
size_t hidden_units_; size_t hidden_units_;
size_t inter_size_; size_t inter_size_;
WeightType weight_type_; WeightType weight_type_;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/nvtx_utils.h" #include "src/turbomind/utils/nvtx_utils.h"
#include <string> #include <string>
// #include <glog/logging.h> // #include <glog/logging.h>
...@@ -56,6 +57,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -56,6 +57,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
const int inference_batch_size, const int inference_batch_size,
const int beam_width, const int beam_width,
const int head_num, const int head_num,
const int kv_head_num,
const int size_per_head, const int size_per_head,
const int rotary_embedding_dim, const int rotary_embedding_dim,
const int memory_max_len, const int memory_max_len,
...@@ -81,11 +83,11 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -81,11 +83,11 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
// Prepare the parameters. // Prepare the parameters.
Masked_multihead_attention_params<DataType> params; Masked_multihead_attention_params<DataType> params;
memset(&params, 0, sizeof(params)); memset(&params, 0, sizeof(params));
int hidden_units = head_num * size_per_head; // int hidden_units = head_num * size_per_head;
if (qkv_bias != nullptr) { if (qkv_bias != nullptr) {
params.q_bias = reinterpret_cast<const DataType*>(qkv_bias); params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + hidden_units; params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + head_num * size_per_head;
params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + 2 * hidden_units; params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + (head_num + kv_head_num) * size_per_head;
} }
else { else {
params.q_bias = nullptr; params.q_bias = nullptr;
...@@ -97,13 +99,16 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -97,13 +99,16 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.out = reinterpret_cast<DataType*>(context_buf); params.out = reinterpret_cast<DataType*>(context_buf);
// Set the input buffers. // Set the input buffers.
// [B, nH + kvH, D]
params.q = reinterpret_cast<const DataType*>(qkv_buf); params.q = reinterpret_cast<const DataType*>(qkv_buf);
params.k = reinterpret_cast<const DataType*>(qkv_buf) + hidden_units; params.k = reinterpret_cast<const DataType*>(qkv_buf) + head_num * size_per_head;
params.v = reinterpret_cast<const DataType*>(qkv_buf) + 2 * hidden_units; params.v = reinterpret_cast<const DataType*>(qkv_buf) + (head_num + kv_head_num) * size_per_head;
params.stride = 3 * hidden_units; params.stride = (head_num + 2 * kv_head_num) * size_per_head;
params.finished = const_cast<bool*>(finished); params.finished = const_cast<bool*>(finished);
FT_CHECK(k_cache_per_sample && v_cache_per_sample);
params.k_cache = reinterpret_cast<DataType*>(key_cache); params.k_cache = reinterpret_cast<DataType*>(key_cache);
params.v_cache = reinterpret_cast<DataType*>(value_cache); params.v_cache = reinterpret_cast<DataType*>(value_cache);
params.k_cache_per_sample = reinterpret_cast<DataType**>(k_cache_per_sample); params.k_cache_per_sample = reinterpret_cast<DataType**>(k_cache_per_sample);
...@@ -118,8 +123,10 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -118,8 +123,10 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.max_prefix_prompt_length = max_prefix_prompt_length; params.max_prefix_prompt_length = max_prefix_prompt_length;
params.length_per_sample = sequence_lengths; // max_input_length + current output length params.length_per_sample = sequence_lengths; // max_input_length + current output length
// timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
params.timestep = step + max_prefix_prompt_length - 1; params.timestep = step + max_prefix_prompt_length - 1;
params.num_heads = head_num; params.num_heads = head_num;
params.num_kv_heads = kv_head_num;
params.hidden_size_per_head = size_per_head; params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim; params.rotary_embedding_dim = rotary_embedding_dim;
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust) // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
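Outside the diff itself: a minimal standalone sketch of the fused-QKV offsets used above, with an assumed Llama-2 70B shape (64 query heads sharing 8 KV heads, 128 dims per head). The struct and helper names are hypothetical; when kv_head_num equals head_num the layout reduces to the old hidden_units / 2 * hidden_units / 3 * hidden_units offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Element offsets of q/k/v inside one token's fused QKV row laid out as
// [q: head_num * D | k: kv_head_num * D | v: kv_head_num * D].
// Names are illustrative, not TurboMind API.
struct QkvOffsets {
    std::size_t q, k, v, stride;
};

QkvOffsets fused_qkv_offsets(std::size_t head_num, std::size_t kv_head_num, std::size_t size_per_head)
{
    QkvOffsets o{};
    o.q      = 0;
    o.k      = head_num * size_per_head;                      // k starts after all query heads
    o.v      = (head_num + kv_head_num) * size_per_head;      // v starts after the key heads
    o.stride = (head_num + 2 * kv_head_num) * size_per_head;  // one token's fused row
    return o;
}

int main()
{
    // Assumed Llama-2 70B shape: 64 query heads, 8 KV heads, 128 dims per head.
    const QkvOffsets gqa = fused_qkv_offsets(64, 8, 128);
    // With kv_head_num == head_num the row degenerates to the old 3 * hidden_units layout.
    const QkvOffsets mha = fused_qkv_offsets(64, 64, 128);
    assert(mha.stride == 3 * 64 * 128);
    std::printf("GQA row: %zu elems, MHA row: %zu elems\n", gqa.stride, mha.stride);
    return 0;
}
```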
...@@ -158,8 +165,11 @@ template<typename T> ...@@ -158,8 +165,11 @@ template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len) void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len)
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
qkv_buf_ =
reinterpret_cast<T*>(allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * 3 * local_hidden_units_, false)); const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
qkv_buf_ = reinterpret_cast<T*>(
allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * local_q_kv_head_num * size_per_head_, false));
context_buf_ = context_buf_ =
reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false)); reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
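A quick sanity check on the new qkv_buf_ sizing, using assumed 70B / TP=8 numbers rather than anything stated in the diff; under GQA the fused buffer is smaller than the old 3 * local_hidden_units_ estimate.

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Assumed Llama-2 70B shapes under 8-way tensor parallelism (not stated in the diff).
    const std::size_t head_num = 64, kv_head_num = 8, size_per_head = 128, tp = 8;
    const std::size_t local_head_num    = head_num / tp;     // 8
    const std::size_t local_kv_head_num = kv_head_num / tp;  // 1
    const std::size_t batch_size = 32;

    // Previous MHA-style size: batch * 3 * local_hidden_units elements.
    const std::size_t old_elems = batch_size * 3 * local_head_num * size_per_head;
    // GQA-aware size matching the reMalloc above: batch * (q + 2 * kv local heads) * D.
    const std::size_t new_elems = batch_size * (local_head_num + 2 * local_kv_head_num) * size_per_head;

    std::printf("qkv_buf_ elements per decode step: old=%zu new=%zu\n", old_elems, new_elems);
    return 0;
}
```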
...@@ -197,7 +207,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o ...@@ -197,7 +207,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
* *
* output tensors: * output tensors:
* \param attention_output [batch_size, hidden_units], * \param attention_output [batch_size, hidden_units],
* \param key_cache [batch, local_head_num, size_per_head / x, memory_max_len, x] * \param key_cache [batch, local_head_num, memory_max_len, size_per_head]
* \param value_cache [batch, local_head_num, memory_max_len, size_per_head] * \param value_cache [batch, local_head_num, memory_max_len, size_per_head]
*/ */
...@@ -228,7 +238,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o ...@@ -228,7 +238,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv); linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
POP_RANGE; POP_RANGE;
const auto kv_cache_layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_; const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
const int memory_len = max_seq_len; const int memory_len = max_seq_len;
fusedQKV_masked_attention_dispatch<T>( fusedQKV_masked_attention_dispatch<T>(
...@@ -248,6 +258,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o ...@@ -248,6 +258,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
batch_size, batch_size,
beam_width, beam_width,
local_head_num_, local_head_num_,
local_kv_head_num_,
size_per_head_, size_per_head_,
rotary_embedding_dim_, rotary_embedding_dim_,
memory_len, memory_len,
......
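For orientation, a sketch of the per-sample KV cache addressing consistent with the layer offset computed above (layer_id * local_kv_head_num_ * max_seq_len * size_per_head_), assuming a [kv_head, timestep, dim] order within a layer. The helper name and the concrete shapes are illustrative only.

```cpp
#include <cstddef>
#include <cstdio>

// Per-sample KV cache element offset under the layer offset used above and an
// assumed [kv_head, timestep, dim] order within each layer. Not TurboMind API.
std::size_t kv_cache_elem_offset(std::size_t layer_id,
                                 std::size_t kv_head,
                                 std::size_t timestep,
                                 std::size_t dim,
                                 std::size_t local_kv_head_num,
                                 std::size_t max_seq_len,
                                 std::size_t size_per_head)
{
    const std::size_t layer_offset = layer_id * local_kv_head_num * max_seq_len * size_per_head;
    return layer_offset + kv_head * max_seq_len * size_per_head + timestep * size_per_head + dim;
}

int main()
{
    // Assumed shapes: layer 2, the single local KV head of a 70B/TP=8 setup,
    // token position 10, first dim, 4096-token cache window, 128 dims per head.
    std::printf("element offset = %zu\n",
                kv_cache_elem_offset(2, 0, 10, 0, /*local_kv_head_num=*/1,
                                     /*max_seq_len=*/4096, /*size_per_head=*/128));
    return 0;
}
```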
...@@ -34,6 +34,7 @@ public: ...@@ -34,6 +34,7 @@ public:
void allocateBuffer(size_t batch_size, int key_len, int max_memory_len); void allocateBuffer(size_t batch_size, int key_len, int max_memory_len);
LlamaDecoderSelfAttentionLayer(size_t head_num, LlamaDecoderSelfAttentionLayer(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t rotary_embedding_dim, size_t rotary_embedding_dim,
bool neox_rotary_style, bool neox_rotary_style,
...@@ -44,9 +45,11 @@ public: ...@@ -44,9 +45,11 @@ public:
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
int quant_policy): int quant_policy):
head_num_(head_num), head_num_(head_num),
kv_head_num_(kv_head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head), hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_), local_head_num_(head_num / tensor_para.world_size_),
local_kv_head_num_(kv_head_num_ / tensor_para.world_size_),
local_hidden_units_(hidden_units_ / tensor_para.world_size_), local_hidden_units_(hidden_units_ / tensor_para.world_size_),
rotary_embedding_dim_(rotary_embedding_dim), rotary_embedding_dim_(rotary_embedding_dim),
neox_rotary_style_(neox_rotary_style), neox_rotary_style_(neox_rotary_style),
...@@ -68,9 +71,11 @@ public: ...@@ -68,9 +71,11 @@ public:
private: private:
const size_t head_num_; const size_t head_num_;
const size_t kv_head_num_;
const size_t size_per_head_; const size_t size_per_head_;
const size_t hidden_units_; const size_t hidden_units_;
const size_t local_head_num_; const size_t local_head_num_;
const size_t local_kv_head_num_;
const size_t local_hidden_units_; const size_t local_hidden_units_;
const size_t rotary_embedding_dim_; const size_t rotary_embedding_dim_;
const bool is_free_buffer_after_forward_; const bool is_free_buffer_after_forward_;
......
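The new local_kv_head_num_ member follows the same tensor-parallel split as the query heads. A sketch of that split, assuming (as the integer division above does) that both head counts divide evenly by the TP world size:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Illustrative tensor-parallel split of query and KV heads, mirroring the
// local_head_num_ / local_kv_head_num_ members above. Assumes both head
// counts divide evenly by the TP world size.
struct LocalHeads {
    std::size_t q_heads;
    std::size_t kv_heads;
};

LocalHeads split_heads(std::size_t head_num, std::size_t kv_head_num, std::size_t tp_world_size)
{
    assert(head_num % tp_world_size == 0);
    assert(kv_head_num % tp_world_size == 0);  // e.g. Llama-2 70B has 8 KV heads
    return {head_num / tp_world_size, kv_head_num / tp_world_size};
}

int main()
{
    const LocalHeads h = split_heads(/*head_num=*/64, /*kv_head_num=*/8, /*tp_world_size=*/8);
    std::printf("per-rank heads: q=%zu kv=%zu\n", h.q_heads, h.kv_heads);
    return 0;
}
```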
...@@ -39,6 +39,7 @@ namespace turbomind { ...@@ -39,6 +39,7 @@ namespace turbomind {
template<typename T> template<typename T>
LlamaV2<T>::LlamaV2(size_t head_num, LlamaV2<T>::LlamaV2(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
...@@ -102,8 +103,11 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -102,8 +103,11 @@ LlamaV2<T>::LlamaV2(size_t head_num,
else { else {
elem_bits = sizeof(T) * 8; elem_bits = sizeof(T) * 8;
} }
const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_, kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_,
local_head_num_, local_kv_head_num,
size_per_head_, size_per_head_,
session_len, session_len,
elem_bits, elem_bits,
...@@ -111,7 +115,7 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -111,7 +115,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
cache_chunk_size, cache_chunk_size,
tensor_para.rank_, tensor_para.rank_,
allocator); allocator);
initialize(use_context_fmha, quant_policy); initialize(kv_head_num, use_context_fmha, quant_policy);
start(); start();
} }
...@@ -126,11 +130,12 @@ LlamaV2<T>::~LlamaV2() ...@@ -126,11 +130,12 @@ LlamaV2<T>::~LlamaV2()
} }
template<typename T> template<typename T>
void LlamaV2<T>::initialize(bool use_context_fmha, int quant_policy) void LlamaV2<T>::initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy)
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
context_decoder_ = new LlamaContextDecoder<T>(head_num_, context_decoder_ = new LlamaContextDecoder<T>(head_num_,
kv_head_num,
size_per_head_, size_per_head_,
inter_size_, inter_size_,
num_layer_, num_layer_,
...@@ -145,6 +150,7 @@ void LlamaV2<T>::initialize(bool use_context_fmha, int quant_policy) ...@@ -145,6 +150,7 @@ void LlamaV2<T>::initialize(bool use_context_fmha, int quant_policy)
quant_policy); quant_policy);
decoder_ = new LlamaDecoder<T>(head_num_, decoder_ = new LlamaDecoder<T>(head_num_,
kv_head_num,
size_per_head_, size_per_head_,
inter_size_, inter_size_,
num_layer_, num_layer_,
......
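Because the cache manager is now constructed with local_kv_head_num rather than local_head_num, the KV cache shrinks in proportion to kv_head_num / head_num. A back-of-the-envelope sketch with assumed Llama-2 70B shapes (not taken from the diff):

```cpp
#include <cstddef>
#include <cstdio>

// Rough KV cache footprint per sequence, per rank, once the cache is sized by
// local_kv_head_num. Shapes below are assumed 70B/TP=8 values for illustration.
std::size_t kv_cache_bytes(std::size_t num_layer, std::size_t local_kv_head_num,
                           std::size_t size_per_head, std::size_t session_len,
                           std::size_t elem_bytes)
{
    return 2 /* K and V */ * num_layer * local_kv_head_num * session_len
           * size_per_head * elem_bytes;
}

int main()
{
    const std::size_t layers = 80, dims = 128, session_len = 4096, fp16_bytes = 2;
    // 70B at TP=8: 1 local KV head with GQA vs 8 local heads if sized like MHA.
    std::printf("GQA: %zu MiB, MHA-sized: %zu MiB per sequence per rank\n",
                kv_cache_bytes(layers, 1, dims, session_len, fp16_bytes) >> 20,
                kv_cache_bytes(layers, 8, dims, session_len, fp16_bytes) >> 20);
    return 0;
}
```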
...@@ -49,6 +49,7 @@ public: ...@@ -49,6 +49,7 @@ public:
~LlamaV2(); ~LlamaV2();
LlamaV2(size_t head_num, LlamaV2(size_t head_num,
size_t kv_head_num,
size_t size_per_head, size_t size_per_head,
size_t inter_size, size_t inter_size,
size_t num_layer, size_t num_layer,
...@@ -90,7 +91,7 @@ private: ...@@ -90,7 +91,7 @@ private:
void internalThreadEntry(int device_id); void internalThreadEntry(int device_id);
void initialize(bool use_context_fmha, int quant_policy); void initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy);
void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
......