Unverified commit f07b697b authored by Li Zhang, committed by GitHub

[Feature] Support Llama-2 with GQA (#147)

* add GQA for llama2

* fix model conversion

* fix lint & remove dev log

* update news

* minor

* fix allocation size

* fix split_dim for w_qkv.bias
parent 2a475478
......@@ -13,8 +13,9 @@ ______________________________________________________________________
## News 🎉
- \[2023/07\] TurboMind supports Llama-2 70B with GQA.
- \[2023/07\] TurboMind supports Llama-2 7B/13B.
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports llama2 7b/13b.
______________________________________________________________________
......
......@@ -13,8 +13,9 @@ ______________________________________________________________________
## 更新 🎉
- \[2023/07\] TurboMind 支持使用 GQA 的 Llama-2 70B 模型
- \[2023/07\] TurboMind 支持 Llama-2 7B/13B 模型
- \[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理
- \[2023/07\] TurboMind 支持 Llama2 7b/13b 模型
______________________________________________________________________
......
......@@ -3,6 +3,6 @@
add_executable(llama_triton_example llama_triton_example.cc)
target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
nvtx_utils word_list glog)
nvtx_utils word_list)
install(TARGETS llama_triton_example DESTINATION ${CMAKE_INSTALL_PREFIX}/bin)
......@@ -95,6 +95,7 @@ def tokenizer_info(model_path: str):
def export(model_name: str,
num_layer: int,
norm_eps: float,
kv_head_num: int,
model_params: dict,
tokenizer_path: str,
out_dir: str,
......@@ -133,10 +134,12 @@ def export(model_name: str,
if key == 'w_qkv' and ext == 'bias':
attn_bias = True
copy = False
if key in ['w1', 'w3', 'w_qkv']:
if key in ['w1', 'w3']:
split_dim = -1
if key == 'w1':
inter_size = param_data.shape[-1]
elif key == 'w_qkv':
split_dim = -2
elif key in ['w2', 'wo']:
if ext in ['scales', 'zeros', 'bias']:
copy = True
......@@ -167,6 +170,7 @@ def export(model_name: str,
cfg = dict(llama=dict(
model_name=model_name,
head_num=head_num,
kv_head_num=kv_head_num,
size_per_head=size_per_head,
vocab_size=vocab_size,
num_layer=num_layer,
......@@ -184,7 +188,7 @@ def export(model_name: str,
step_length=1,
cache_max_entry_count=48,
cache_chunk_size=1,
use_context_fmha=1,
use_context_fmha=int(kv_head_num == head_num),
quant_policy=0,
tensor_para_size=tp))
......@@ -198,6 +202,15 @@ def export(model_name: str,
return True
def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
dim: int):
def reshape(x):
return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
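For readers following the conversion change: `merge_qkv` concatenates Q/K/V per tensor-parallel rank instead of stacking them along a new output dimension, which is also why `w_qkv` is now split along dim `-2` in `export`. A minimal shape walkthrough with made-up GQA sizes (8 query heads, 2 kv heads, head dim 4, tp = 2); only the sizes are assumptions, the reshape/cat logic mirrors the function above:

```python
import torch

head_num, kv_head_num, size_per_head, tp = 8, 2, 4, 2
in_dim = head_num * size_per_head

q = torch.randn(in_dim, head_num * size_per_head)     # [32, 32]
k = torch.randn(in_dim, kv_head_num * size_per_head)  # [32, 8]
v = torch.randn(in_dim, kv_head_num * size_per_head)  # [32, 8]

def merge_qkv(q, k, v, tp, dim):
    # same logic as the merge_qkv added in this diff
    def reshape(x):
        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)

w_qkv = merge_qkv(q, k, v, tp, dim=2)
print(w_qkv.shape)  # [32, 2, 24] == [in_dim, tp, (H + 2*kvH) * D / tp]

# Splitting along dim -2 (the tp axis) yields one contiguous Q|K|V block
# per rank, which is what split_dim = -2 for 'w_qkv' relies on.
ranks = torch.split(w_qkv, 1, dim=-2)
print(len(ranks), ranks[0].shape)  # 2, [32, 1, 24]
```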
def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int):
"""Deploy a model with huggingface transformers' format.
......@@ -223,6 +236,8 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f)
num_layer = model_arg['n_layers']
norm_eps = model_arg['norm_eps']
head_num = model_arg.get('n_heads', 32)
kv_head_num = model_arg.get('n_kv_heads', head_num)
except Exception as e:
print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}')
return False
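As a sanity check on the `.get()` fallbacks above, a sketch of how the two fields resolve for an illustrative `params.json`; the concrete values here are assumptions for demonstration, not read from any shipped checkpoint:

```python
import json

# Illustrative Llama-2-70B-style params.json content (hypothetical example).
params_json = '{"n_layers": 80, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05}'
model_arg = json.loads(params_json)

head_num = model_arg.get('n_heads', 32)
kv_head_num = model_arg.get('n_kv_heads', head_num)  # falls back to MHA when the key is absent
print(head_num, kv_head_num)  # 64 8 -> GQA; a file without n_kv_heads gives kv_head_num == head_num
```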
......@@ -268,7 +283,6 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
else: # bias
param = get_param(param_name, [size])
param.data = param_data
elif i == 0:
param = get_param(param_name, param_data.size())
param.data = param_data
......@@ -291,14 +305,14 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
qkv = tuple(map(model_params.pop, _qkv))
except KeyError:
break
# concat by output_dims
qkv = torch.stack(qkv, dim=qkv[0].dim() - 1)
# concat by heads
qkv = merge_qkv(*qkv, tp, dim=2 if t == 'weight' else 1)
print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv
assert i == 0 or num_layer == i, f'miss matched layers: {num_layer} vs {i}'
return export(model_name, num_layer, norm_eps, model_params,
return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp)
......@@ -349,6 +363,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
except Exception as e:
print(f'get "num_hidden_layers" and "rms_norm_eps" from '
f'{params_path} failed: {e}')
......@@ -416,11 +434,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
q = permute(q)
k = permute(k)
if suffix == _qweight: # weight, qweight
# insert a dimension for splitting heads later
qkv = torch.stack((q, k, v), dim=1)
qkv = merge_qkv(q, k, v, tp, dim=2)
print(suffix, qkv.shape)
else: # scales, zeros, bias
qkv = torch.stack((q.squeeze(), k.squeeze(), v.squeeze()),
dim=0).squeeze(dim=-1)
qkv = merge_qkv(q, k, v, tp, dim=1)
print(suffix, qkv.shape)
for k, v in [('w_qkv', qkv), ('wo', o)]:
model_params[f'layers.{i}.attention.{k}.{suffix}'] = v
......@@ -456,7 +473,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
for ft, hf in other:
model_params[ft] = get_tensor(hf)
return export(model_name, num_layer, norm_eps, model_params,
return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp)
......
......@@ -132,6 +132,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base<T> {
T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0;
bool k_cache_interleaved = true;
int num_kv_heads = 0;
};
template<class T>
......
......@@ -677,7 +677,7 @@ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
: "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
return c;
}
......@@ -1349,16 +1349,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (params.finished != nullptr && params.finished[bi] == true) {
return;
}
// The beam idx
const int beami = bi % params.beam_width;
// The "beam-aware" batch idx
const int bbi = bi / params.beam_width;
// The head.
const int hi = blockIdx.x;
// Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi;
// Combine the "beam-aware" batch idx and the head indices.
const int bbhi = bbi * params.beam_width * params.num_heads + hi;
const int head_n_rep = params.num_heads / params.num_kv_heads;
const int kvhi = hi / head_n_rep; // heads in the same group collapse to the same kv head
const bool group_leader = hi % head_n_rep == 0; // only group leader writes to kv cache
// The thread in the block.
const int tidx = threadIdx.x;
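The grouping arithmetic introduced above maps each query head to the KV head it shares, and only the first head in each group (the "group leader") writes to the KV cache. A small Python restatement of the same index math, assuming `num_heads` is a multiple of `num_kv_heads` as the kernel requires:

```python
# Hypothetical GQA configuration: 8 query heads sharing 2 kv heads.
num_heads, num_kv_heads = 8, 2
head_n_rep = num_heads // num_kv_heads   # 4 query heads per kv head

for hi in range(num_heads):
    kvhi = hi // head_n_rep              # kv head this query head attends with
    group_leader = hi % head_n_rep == 0  # only this head stores K/V for its group
    print(hi, kvhi, group_leader)
# hi 0..3 -> kvhi 0 (leader at hi == 0), hi 4..7 -> kvhi 1 (leader at hi == 4)
```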
......@@ -1369,8 +1371,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
float qk = 0.0F;
int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
const size_t bi_seq_len_offset = bi * params.memory_max_len;
const int tlength = params.length_per_sample[bi] + params.max_prefix_prompt_length;
......@@ -1380,10 +1380,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP;
const int q_base_offset = bi * params.stride + hi * Dh;
const int k_base_offset = bi * params.stride + kvhi * Dh;
// The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
const int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
const int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
// The offset in the bias buffer.
int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const int k_bias_offset = kvhi * Dh + tidx * QK_VEC_SIZE;
// past kv quant param
const float k_scale = params.attention_k_scale;
......@@ -1393,31 +1399,30 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
Qk_vec_k q;
zero(q);
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset]));
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[q_offset]));
}
Qk_vec_k k;
zero(k);
{
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[k_offset])) :
k;
}
// Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias;
zero(q_bias);
q_bias =
(!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[qk_bias_offset])) :
q_bias;
q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[q_bias_offset])) :
q_bias;
Qk_vec_k k_bias;
zero(k_bias);
if (handle_kv) {
k_bias =
!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[qk_bias_offset])) :
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset])) :
k_bias;
}
......@@ -1454,7 +1459,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
if (handle_kv) {
if (handle_kv && group_leader) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
if (!params.k_cache_per_sample) {
......@@ -1476,12 +1481,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else {
int offset;
if (params.k_cache_interleaved) {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh
offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci;
}
else {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;
offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;
}
if (not QUANT_POLICY) {
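For the per-sample cache write above, the two offset formulas correspond to an interleaved layout `[kvH, Dh/x, S, x]` and a linear layout `[kvH, S, Dh]`. A Python restatement of that arithmetic under assumed sizes (variable names mirror the kernel, values are made up):

```python
# Hypothetical sizes: S = memory_max_len, Dh = head size, x = elements per 16B chunk.
S, Dh, x = 6, 8, 4
kv_cache_per_sample_offset = 0  # layer offset, ignored in this sketch

def k_cache_offset(kvhi, t, co, ci, interleaved):
    if interleaved:  # layout [kvH, Dh/x, S, x]
        return kv_cache_per_sample_offset + kvhi * S * Dh + co * S * x + t * x + ci
    else:            # layout [kvH, S, Dh]
        return kv_cache_per_sample_offset + kvhi * S * Dh + t * Dh + co * x + ci

# Both layouts consume S * Dh slots per kv head, so switching hi -> kvhi shrinks
# the per-layer cache stride from num_heads * S * Dh to num_kv_heads * S * Dh.
print(k_cache_offset(kvhi=1, t=2, co=1, ci=3, interleaved=True))
print(k_cache_offset(kvhi=1, t=2, co=1, ci=3, interleaved=False))
```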
......@@ -1577,7 +1582,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (not QUANT_POLICY) {
k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + ki) :
+ kvhi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
......@@ -1586,7 +1591,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// convert k_cache_per_sample to int8
if (params.k_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
}
else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
......@@ -1765,7 +1770,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (not QUANT_POLICY) {
v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + vi) :
+ kvhi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
......@@ -1774,7 +1779,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else if (QUANT_POLICY == 4) {
if (params.v_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
}
else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
......@@ -1787,22 +1792,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
zero(v_bias);
// if( vo == params.timestep % V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
if (handle_kv) {
if (vo == tlength % V_PER_ITER) {
// Trigger the loads from the V bias buffer.
if (params.v_bias != nullptr) {
v_bias = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&params.v_bias[hi * Dh + vi]));
}
}
}
}
// From previous, before values, step
// Also make sure the logits are in shared memory.
__syncthreads();
......@@ -1924,14 +1913,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
V_vec_k v;
// Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
const auto v_offset = k_base_offset + vi;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
// Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
if (params.v_bias != nullptr) {
V_vec_k v_bias = *reinterpret_cast<const V_vec_k*>(&params.v_bias[kvhi * Dh + vi]);
v = add(v, v_bias);
}
// Compute the V values with bias.
if (handle_kv) {
v = add(v, v_bias);
// Store the V values to cache
if (handle_kv && group_leader) {
// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
......
......@@ -20,6 +20,7 @@
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
namespace turbomind {
......@@ -1336,6 +1337,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
const int batch_size,
const int seq_len,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const bool neox_rotary_style)
......@@ -1396,7 +1398,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
const int n = head_num * size_per_head;
// const int n = head_num * size_per_head;
// the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len)
// and Q [0..seq_len)
......@@ -1405,33 +1407,48 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
// NOTE: q has seq len excluding prefix prompt
// src QKV: [batch, time, 3, head, hidden]
const int src_q_idx = token_idx * 3 * n + hidden_idx;
const int src_k_idx = token_idx * 3 * n + hidden_idx + n;
const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n;
// const int src_q_idx = token_idx * 3 * n + hidden_idx;
// const int src_k_idx = token_idx * 3 * n + hidden_idx + n;
// const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n;
const int q_kv_head_num = head_num + 2 * kv_head_num;
const int k_offset = head_num * size_per_head;
const int v_offset = k_offset + kv_head_num * size_per_head;
// src QKV: [batch, time, q_kv_head_num, hidden]
const int src_q_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx;
const int src_k_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + k_offset;
const int src_v_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + v_offset;
Vec_t q, k, v;
Vec_t q_bias, k_bias, v_bias;
// load Q and apply bias
if (!is_masked) {
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (qkv_bias) {
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + 2 * n]);
q = mmha::add(q, q_bias);
}
}
if (qkv_bias) {
q = mmha::add(q, q_bias);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
// load KV and apply bias
if (!is_masked && head_idx < kv_head_num) {
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (qkv_bias) {
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
}
}
const int t_offset = history_length ? history_length[batch_idx] : 0;
if (!neox_rotary_style) {
// TODO: unused computation on k if GQA is used
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx + t_offset);
}
else {
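This hunk switches the fused source buffer from `[batch, time, 3, head_num, size_per_head]` to `[batch, time, head_num + 2*kv_head_num, size_per_head]`. A hedged sketch of the new source offsets with illustrative sizes (8 query heads, 2 kv heads, head dim 4):

```python
head_num, kv_head_num, size_per_head = 8, 2, 4
q_kv_head_num = head_num + 2 * kv_head_num          # 12 head "slots" per token
k_offset = head_num * size_per_head                 # start of the K block inside a token
v_offset = k_offset + kv_head_num * size_per_head   # start of the V block

def src_indices(token_idx, head_idx, tidx=0, vec_size=1):
    hidden_idx = head_idx * size_per_head + tidx * vec_size
    base = token_idx * q_kv_head_num * size_per_head
    # K/V entries only exist for head_idx < kv_head_num, as the guards above enforce
    return base + hidden_idx, base + hidden_idx + k_offset, base + hidden_idx + v_offset

print(src_indices(token_idx=1, head_idx=0))  # (48, 80, 88)
```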
......@@ -1472,23 +1489,28 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
}
}
if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
if (head_idx < kv_head_num) {
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
}
}
const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len
+ seq_idx * size_per_head + tidx * vec_size;
const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * head_num
const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num
+ head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head
+ tidx * vec_size;
if (!is_masked) {
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k;
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v;
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
if (head_idx < kv_head_num) {
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k;
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v;
}
}
}
......@@ -1504,6 +1526,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
batch_size, \
seq_len, \
head_num, \
kv_head_num, \
size_per_head, \
rotary_embedding_dim, \
neox_rotary_style);
......@@ -1521,6 +1544,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len,
const int token_num,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
......@@ -1528,42 +1552,18 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int int8_mode,
cudaStream_t stream)
{
// [bs, seq_len, 3, head, Dh]
if (rotary_embedding_dim == 0 && param.max_prefix_prompt_length == 0) {
const int m = token_num;
const int n = head_num * size_per_head;
dim3 block(384);
dim3 grid((int)(ceil(1.0 * m * n / 384)));
add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
QKV,
qkv_bias,
padding_offset,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
scale,
int8_mode);
FT_CHECK(rotary_embedding_dim);
FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec)
// To implement rotary embeddings, each thread processes two QKV elems:
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num);
size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0;
// NOTE: add offset for rotary embedding
if (param.max_prefix_prompt_length == 0) {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
}
else {
FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec)
// To implement rotary embeddings, each thread processes two QKV elems:
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num);
size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0;
// NOTE: add offset for rotary embedding
// add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(
// q_buf, k_buf, v_buf, param, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head,
// rotary_embedding_dim);
if (param.max_prefix_prompt_length == 0) {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
}
else {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true);
}
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true);
}
}
......@@ -1580,6 +1580,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len, \
const int token_num, \
const int head_num, \
const int kv_head_num, \
const int size_per_head, \
const int rotary_embedding_dim, \
const int neox_rotary_style, \
......
......@@ -112,6 +112,7 @@ struct PrefixPromptBatchWeightsParam {
// l * 2 * hidden_units_ / tensor_para_.world_size_
const size_t prefix_prompt_layer_offset_per_seq = 0;
};
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
......@@ -125,83 +126,13 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len,
const int token_num,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream);
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam<T> param,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf,
k_buf,
v_buf,
param,
QKV,
qkv_bias,
padding_offset,
nullptr,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
rotary_embedding_dim,
neox_rotary_style,
scale,
int8_mode,
stream);
}
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf,
k_buf,
v_buf,
PrefixPromptBatchWeightsParam<T>{},
QKV,
qkv_bias,
padding_offset,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
0,
false,
(float*)nullptr,
0,
stream);
}
template<typename T>
void invokeTranspose4d(T* dst,
......
......@@ -27,6 +27,7 @@
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
namespace turbomind {
......@@ -38,13 +39,17 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
// no padding
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * 3 * local_hidden_units_, true);
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);
// padding is rebuilt for q/k/v_buf_2_
q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * 3 * batch_size * max_q_len * local_hidden_units_, true);
k_buf_2_ = q_buf_2_ + batch_size * max_q_len * local_hidden_units_;
v_buf_2_ = k_buf_2_ + batch_size * max_q_len * local_hidden_units_;
// [qH + 2kvH, B, S, D]
q_buf_2_ = (T*)allocator_->reMalloc(
q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;
if (use_fmha_) {
FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
......@@ -53,6 +58,7 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
}
}
else {
// kv heads are repeated for unfused attention
k_cache_buf_ = (T*)allocator_->reMalloc(
k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;
......@@ -61,12 +67,12 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
(T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);
// qkv_buf_2_ has padding
qkv_buf_2_ =
(T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_hidden_units_, true);
qkv_buf_2_ = (T*)allocator_->reMalloc(
qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
}
// qkv_buf_3_ padding is removed
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_hidden_units_, true);
qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);
is_allocate_buffer_ = true;
}
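To make the new buffer sizes concrete: `qkv_buf_` now holds `(H + 2*kvH) * D` elements per token, and `q/k/v_buf_2_` are carved out of one `[qH + 2kvH, B, S, D]` allocation. A sketch in element counts (multiply by `sizeof(T)` for bytes); all sizes here are illustrative assumptions:

```python
# Hypothetical per-rank sizes: 8 query heads, 2 kv heads, head dim 4.
local_head_num, local_kv_head_num, size_per_head = 8, 2, 4
batch_size, max_q_len, num_token = 2, 16, 24

local_q_kv_head_num = local_head_num + 2 * local_kv_head_num

qkv_buf_elems = num_token * local_q_kv_head_num * size_per_head                # fused QKV, no padding
q_buf_2_elems = local_q_kv_head_num * batch_size * max_q_len * size_per_head   # [qH + 2kvH, B, S, D]

# k_buf_2_ / v_buf_2_ are views into the same allocation:
k_buf_2_offset = local_head_num * batch_size * max_q_len * size_per_head
v_buf_2_offset = k_buf_2_offset + local_kv_head_num * batch_size * max_q_len * size_per_head
print(qkv_buf_elems, q_buf_2_elems, k_buf_2_offset, v_buf_2_offset)
```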
......@@ -152,7 +158,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
//////////////////////////////////////////////
/// transpose qkv & apply rotary embedding & rebuild padding
/// qkv [B, s, 3, H, D] -> (q [B, H, s, D], k [B, H, s, D], v [B, H, s, D])
/// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D])
invokeAddFusedQKVBiasTranspose(q_buf_2_,
k_buf_2_,
v_buf_2_,
......@@ -165,6 +171,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
max_q_len, // seq_len
num_token, // batch_size * seq_len
local_head_num_,
local_kv_head_num_,
size_per_head_,
rotary_embedding_dim_,
neox_rotary_style_,
......@@ -173,16 +180,16 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
stream_);
sync_check_cuda_error();
const size_t layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_;
const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
//////////////////////////////////////////////////////////
/// insert the k/v computed from inputs into k/v cache
/// transpose kv -> kv cache
// put k/v_buf from shape [B, H, s, D] to
// k_buf_2 [B, H, s, D] -> key_cache [B, H, S[t:t+s], D/x, x]
// v_buf_2 [B, H, s, D] -> val_cache [B, H, S[t:t+s], D/x, x]
// put k/v_buf from shape [B, kvH, s, D] to
// k_buf_2 [B, kvH, s, D] -> key_cache [B, kvH, S[t:t+s], D/x, x]
// v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x]
invokeExtendKVCache(k_cache_ptrs,
v_cache_ptrs,
layer_offset,
......@@ -194,13 +201,14 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
history_length,
max_seq_len,
size_per_head_,
local_head_num_,
local_kv_head_num_,
stream_,
quant_policy_,
weights->past_kv_scale.data());
sync_check_cuda_error();
if (use_fmha_) {
FT_CHECK(local_head_num_ == local_kv_head_num_);
fusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs,
layer_offset,
......@@ -311,8 +319,8 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c
int quant,
const float* kv_scale)
{
// key_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D]
// val_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D]
// key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
// val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
invokeTransposeKVCache(k_cache_buf_,
v_cache_buf_,
(const T**)key_cache_ptrs,
......@@ -324,6 +332,7 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_c
max_seq_len,
size_per_head_,
local_head_num_,
head_n_rep_,
stream_,
quant,
kv_scale);
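Per the "kv heads are repeated for unfused attention" note earlier in this file, the unfused path still expects `head_num`-sized K/V buffers, so the cache transpose expands each kv head `head_n_rep` times. A hedged PyTorch sketch of the equivalent host-side operation (shapes are made up, and this only approximates what the kernel does on device):

```python
import torch

# Hypothetical shapes: [B, kvH, t+s, D] cache expanded to [B, qH, t+s, D].
B, kvH, T, D = 2, 2, 5, 4
head_n_rep = 4                      # qH = kvH * head_n_rep = 8

k_cache = torch.randn(B, kvH, T, D)
# roughly the effect of passing head_n_rep to invokeTransposeKVCache
k_expanded = k_cache.repeat_interleave(head_n_rep, dim=1)
print(k_expanded.shape)  # torch.Size([2, 8, 5, 4])
```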
......
......@@ -35,6 +35,7 @@ public:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
LlamaContextAttentionLayer(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t rotary_embedding_dim,
bool neox_rotary_style,
......@@ -49,7 +50,8 @@ public:
size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_),
local_hidden_units_(hidden_units_ / tensor_para.world_size_),
local_kv_head_num_(kv_head_num / tensor_para.world_size_),
head_n_rep_(head_num / kv_head_num),
rotary_embedding_dim_(rotary_embedding_dim),
neox_rotary_style_(neox_rotary_style),
tensor_para_(tensor_para),
......@@ -61,6 +63,7 @@ public:
use_fmha_(use_fmha),
quant_policy_(quant_policy)
{
FT_CHECK(head_num % kv_head_num == 0);
}
void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight<T>* weights);
......@@ -93,8 +96,9 @@ private:
const size_t head_num_;
const size_t size_per_head_;
const size_t hidden_units_;
const size_t local_kv_head_num_;
const size_t local_head_num_;
const size_t local_hidden_units_;
const size_t head_n_rep_;
const size_t rotary_embedding_dim_;
const bool is_free_buffer_after_forward_;
......
......@@ -62,11 +62,12 @@ void LlamaContextDecoder<T>::freeBuffer()
}
template<typename T>
void LlamaContextDecoder<T>::initialize(bool use_fmha, int quant_policy)
void LlamaContextDecoder<T>::initialize(size_t kv_head_num, bool use_fmha, int quant_policy)
{
h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
kv_head_num,
size_per_head_,
rotary_embedding_dim_,
false, // neox_rotary_style
......@@ -124,6 +125,7 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
template<typename T>
LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
......@@ -147,7 +149,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
tensor_para_(tensor_para),
data_type_(getTensorType<T>())
{
initialize(use_fmha, quant_policy);
initialize(kv_head_num, use_fmha, quant_policy);
}
template<typename T>
......
......@@ -43,7 +43,7 @@ protected:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
void freeBuffer() override;
void initialize(bool use_fmha, int quant_policy);
void initialize(size_t kv_head_num, bool use_fmha, int quant_policy);
size_t head_num_;
size_t size_per_head_;
......@@ -88,6 +88,7 @@ protected:
public:
LlamaContextDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
......
......@@ -28,6 +28,7 @@ namespace turbomind {
template<typename T>
LlamaDecoder<T>::LlamaDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
......@@ -51,7 +52,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num,
data_type_(getTensorType<T>())
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
initialize(quant_policy);
initialize(kv_head_num, quant_policy);
}
template<typename T>
......@@ -63,11 +64,12 @@ LlamaDecoder<T>::~LlamaDecoder()
}
template<typename T>
void LlamaDecoder<T>::initialize(int quant_policy)
void LlamaDecoder<T>::initialize(size_t kv_head_num, int quant_policy)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
self_attention_layer_ = new LlamaDecoderSelfAttentionLayer<T>(head_num_,
kv_head_num,
size_per_head_,
rotary_embedding_dim_,
false, // neox_rotary_style
......
......@@ -35,7 +35,7 @@ protected:
void allocateBuffer() override; // deprecated
void allocateBuffer(size_t batch_size);
void freeBuffer() override;
void initialize(int quant_policy);
void initialize(size_t kv_head_num, int quant_policy);
size_t head_num_;
size_t size_per_head_;
......@@ -70,6 +70,7 @@ protected:
public:
LlamaDecoder(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
......@@ -80,9 +81,9 @@ public:
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int quant_policy),
int quant_policy);
~LlamaDecoder() override;
~LlamaDecoder() override;
virtual void forward(std::unordered_map<std::string, Tensor>* output_tensors,
const std::unordered_map<std::string, Tensor>* input_tensors,
......
......@@ -25,13 +25,18 @@
namespace turbomind {
template<typename T>
LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t hidden_units,
LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
WeightType weight_type,
bool attn_bias,
size_t tensor_para_size,
size_t tensor_para_rank):
hidden_units_(hidden_units),
head_num_(head_num),
kv_head_num_(kv_head_num),
size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head),
inter_size_(inter_size),
weight_type_(weight_type),
attn_bias_(attn_bias),
......@@ -39,7 +44,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t hidden_units,
tensor_para_rank_(tensor_para_rank)
{
self_attn_weights.qkv.input_dims = hidden_units_;
self_attn_weights.qkv.output_dims = 3 * hidden_units_ / tensor_para_size_;
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
self_attn_weights.qkv.type = weight_type;
self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
......
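With GQA the fused QKV projection is no longer `3 * hidden_units` wide. Plugging in published-looking Llama-2 70B numbers as a rough check (treat the concrete values as illustrative assumptions):

```python
# Llama-2 70B-style configuration (illustrative): 64 query heads, 8 kv heads, head dim 128.
head_num, kv_head_num, size_per_head = 64, 8, 128
hidden_units = head_num * size_per_head                  # 8192
tensor_para_size = 8

qkv_output_dims = (head_num + 2 * kv_head_num) * size_per_head // tensor_para_size
old_formula = 3 * hidden_units // tensor_para_size       # what the MHA-only code assumed

print(qkv_output_dims, old_formula)  # 1280 vs 3072 columns per rank
```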
......@@ -28,7 +28,9 @@ template<typename T>
struct LlamaDecoderLayerWeight {
public:
LlamaDecoderLayerWeight() = delete;
LlamaDecoderLayerWeight(size_t hidden_units,
LlamaDecoderLayerWeight(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
WeightType weight_type,
bool attn_bias,
......@@ -46,6 +48,9 @@ public:
LlamaFfnWeight<T> ffn_weights{};
private:
size_t head_num_;
size_t kv_head_num_;
size_t size_per_head_;
size_t hidden_units_;
size_t inter_size_;
WeightType weight_type_;
......
......@@ -23,6 +23,7 @@
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/nvtx_utils.h"
#include <string>
// #include <glog/logging.h>
......@@ -56,6 +57,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
const int inference_batch_size,
const int beam_width,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int memory_max_len,
......@@ -81,11 +83,11 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
// Prepare the parameters.
Masked_multihead_attention_params<DataType> params;
memset(&params, 0, sizeof(params));
int hidden_units = head_num * size_per_head;
// int hidden_units = head_num * size_per_head;
if (qkv_bias != nullptr) {
params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + hidden_units;
params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + 2 * hidden_units;
params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + head_num * size_per_head;
params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + (head_num + kv_head_num) * size_per_head;
}
else {
params.q_bias = nullptr;
......@@ -97,13 +99,16 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.out = reinterpret_cast<DataType*>(context_buf);
// Set the input buffers.
// [B, nH + kvH, D]
params.q = reinterpret_cast<const DataType*>(qkv_buf);
params.k = reinterpret_cast<const DataType*>(qkv_buf) + hidden_units;
params.v = reinterpret_cast<const DataType*>(qkv_buf) + 2 * hidden_units;
params.k = reinterpret_cast<const DataType*>(qkv_buf) + head_num * size_per_head;
params.v = reinterpret_cast<const DataType*>(qkv_buf) + (head_num + kv_head_num) * size_per_head;
params.stride = 3 * hidden_units;
params.stride = (head_num + 2 * kv_head_num) * size_per_head;
params.finished = const_cast<bool*>(finished);
FT_CHECK(k_cache_per_sample && v_cache_per_sample);
params.k_cache = reinterpret_cast<DataType*>(key_cache);
params.v_cache = reinterpret_cast<DataType*>(value_cache);
params.k_cache_per_sample = reinterpret_cast<DataType**>(k_cache_per_sample);
......@@ -118,8 +123,10 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.max_prefix_prompt_length = max_prefix_prompt_length;
params.length_per_sample = sequence_lengths; // max_input_length + current output length
// timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
params.timestep = step + max_prefix_prompt_length - 1;
params.num_heads = head_num;
params.timestep = step + max_prefix_prompt_length - 1;
params.num_heads = head_num;
params.num_kv_heads = kv_head_num;
params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim;
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
......@@ -158,8 +165,11 @@ template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
qkv_buf_ =
reinterpret_cast<T*>(allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * 3 * local_hidden_units_, false));
const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;
qkv_buf_ = reinterpret_cast<T*>(
allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * local_q_kv_head_num * size_per_head_, false));
context_buf_ =
reinterpret_cast<T*>(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false));
......@@ -197,7 +207,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
*
* output tensors:
* \param attention_output [batch_size, hidden_units],
* \param key_cache [batch, local_head_num, size_per_head / x, memory_max_len, x]
* \param key_cache [batch, local_head_num, memory_max_len, size_per_head]
* \param value_cache [batch, local_head_num, memory_max_len, size_per_head]
*/
......@@ -228,7 +238,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
POP_RANGE;
const auto kv_cache_layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_;
const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_;
const int memory_len = max_seq_len;
fusedQKV_masked_attention_dispatch<T>(
......@@ -248,6 +258,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
batch_size,
beam_width,
local_head_num_,
local_kv_head_num_,
size_per_head_,
rotary_embedding_dim_,
memory_len,
......
......@@ -34,6 +34,7 @@ public:
void allocateBuffer(size_t batch_size, int key_len, int max_memory_len);
LlamaDecoderSelfAttentionLayer(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t rotary_embedding_dim,
bool neox_rotary_style,
......@@ -44,9 +45,11 @@ public:
bool is_free_buffer_after_forward,
int quant_policy):
head_num_(head_num),
kv_head_num_(kv_head_num),
size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / tensor_para.world_size_),
local_kv_head_num_(kv_head_num_ / tensor_para.world_size_),
local_hidden_units_(hidden_units_ / tensor_para.world_size_),
rotary_embedding_dim_(rotary_embedding_dim),
neox_rotary_style_(neox_rotary_style),
......@@ -68,9 +71,11 @@ public:
private:
const size_t head_num_;
const size_t kv_head_num_;
const size_t size_per_head_;
const size_t hidden_units_;
const size_t local_head_num_;
const size_t local_kv_head_num_;
const size_t local_hidden_units_;
const size_t rotary_embedding_dim_;
const bool is_free_buffer_after_forward_;
......
......@@ -39,6 +39,7 @@ namespace turbomind {
template<typename T>
LlamaV2<T>::LlamaV2(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
......@@ -102,8 +103,11 @@ LlamaV2<T>::LlamaV2(size_t head_num,
else {
elem_bits = sizeof(T) * 8;
}
const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_;
kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_,
local_head_num_,
local_kv_head_num,
size_per_head_,
session_len,
elem_bits,
......@@ -111,7 +115,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
cache_chunk_size,
tensor_para.rank_,
allocator);
initialize(use_context_fmha, quant_policy);
initialize(kv_head_num, use_context_fmha, quant_policy);
start();
}
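Sizing the cache manager by `local_kv_head_num` is where the GQA memory saving lands. A back-of-the-envelope sketch, not an exact accounting of LlamaCacheManager, using fp16 and illustrative 70B-like numbers:

```python
# Illustrative 70B-like numbers: 80 layers, 64 query heads vs 8 kv heads, head dim 128.
num_layer, head_num, kv_head_num, size_per_head = 80, 64, 8, 128
session_len, elem_bytes = 4096, 2    # fp16

def kv_cache_bytes(heads):
    return 2 * num_layer * heads * session_len * size_per_head * elem_bytes  # K and V

print(kv_cache_bytes(head_num) / 2**30)      # ~10.0 GiB per sequence if sized by query heads
print(kv_cache_bytes(kv_head_num) / 2**30)   # ~1.25 GiB per sequence when sized by kv heads
```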
......@@ -126,11 +130,12 @@ LlamaV2<T>::~LlamaV2()
}
template<typename T>
void LlamaV2<T>::initialize(bool use_context_fmha, int quant_policy)
void LlamaV2<T>::initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
context_decoder_ = new LlamaContextDecoder<T>(head_num_,
kv_head_num,
size_per_head_,
inter_size_,
num_layer_,
......@@ -145,6 +150,7 @@ void LlamaV2<T>::initialize(bool use_context_fmha, int quant_policy)
quant_policy);
decoder_ = new LlamaDecoder<T>(head_num_,
kv_head_num,
size_per_head_,
inter_size_,
num_layer_,
......
......@@ -49,6 +49,7 @@ public:
~LlamaV2();
LlamaV2(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t num_layer,
......@@ -90,7 +91,7 @@ private:
void internalThreadEntry(int device_id);
void initialize(bool use_context_fmha, int quant_policy);
void initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy);
void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
......