Commit c3d5efa5 authored by PanZezhong
Browse files

support qkv bias

parent 5540d53a
...@@ -77,10 +77,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -77,10 +77,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto dh = meta.dh; auto dh = meta.dh;
auto d = meta.d; auto d = meta.d;
auto dt_logits = meta.dt_logits; auto dt_logits = meta.dt_logits;
// std::cout << "dt_logits: " <<(int)dt_logits << std::endl;
auto di = meta.di / ndev; auto di = meta.di / ndev;
auto dvoc = meta.dvoc; auto dvoc = meta.dvoc;
auto stream = rsrc.stream; auto stream = rsrc.stream;
bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
// Allocate buffers // Allocate buffers
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, stream); auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, stream);
...@@ -128,6 +128,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -128,6 +128,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
workspace_size = std::max(workspace_size, temp_size); workspace_size = std::max(workspace_size, temp_size);
// Attention // Attention
infiniopGemmDescriptor_t desc_attn_qkv, desc_attn_o; infiniopGemmDescriptor_t desc_attn_qkv, desc_attn_o;
infiniopRearrangeDescriptor_t desc_qkv_bias;
if (has_qkv_bias) {
RUN_INFINI(infiniopCreateRearrangeDescriptor(
rsrc.handle, &desc_qkv_bias, qkv_buf->desc()->get(),
TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->get()));
}
RUN_INFINI(infiniopCreateGemmDescriptor( RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_qkv, qkv_buf->desc()->get(), rsrc.handle, &desc_attn_qkv, qkv_buf->desc()->get(),
logits_in->desc()->get(), rsrc.w_attn_qkv[0]->desc()->get())); logits_in->desc()->get(), rsrc.w_attn_qkv[0]->desc()->get()));
...@@ -224,7 +230,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -224,7 +230,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
workspace_size = std::max(workspace_size, temp_size); workspace_size = std::max(workspace_size, temp_size);
// Allocate workspace // Allocate workspace
workspace = rsrc.workspace_allocator->alloc(workspace_size); workspace = rsrc.workspace_allocator->alloc(workspace_size);
// Compute // Compute
for (uint32_t layer = 0; layer < nlayer; layer++) { for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention // 1. Attention
...@@ -234,6 +240,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -234,6 +240,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
logits_out->data(), logits_in->data(), logits_out->data(), logits_in->data(),
rsrc.w_attn_norm[layer]->data(), stream)); rsrc.w_attn_norm[layer]->data(), stream));
// qkv_proj // qkv_proj
if (has_qkv_bias) {
RUN_INFINI(infiniopRearrange(
desc_qkv_bias,
qkv_buf->data(), rsrc.b_attn_qkv.data(), stream));
}
RUN_INFINI(infiniopGemm( RUN_INFINI(infiniopGemm(
desc_attn_qkv, workspace, workspace_size, desc_attn_qkv, workspace, workspace_size,
qkv_buf->data(), logits_out->data(), qkv_buf->data(), logits_out->data(),
...@@ -347,6 +358,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -347,6 +358,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// Clean up // Clean up
infiniopDestroyRMSNormDescriptor(desc_norm); infiniopDestroyRMSNormDescriptor(desc_norm);
infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
infiniopDestroyGemmDescriptor(desc_attn_qkv); infiniopDestroyGemmDescriptor(desc_attn_qkv);
infiniopDestroyGemmDescriptor(desc_attn_o); infiniopDestroyGemmDescriptor(desc_attn_o);
infiniopDestroyRoPEDescriptor(desc_rope_q); infiniopDestroyRoPEDescriptor(desc_rope_q);
...@@ -354,6 +366,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -354,6 +366,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
infiniopDestroyAttentionDescriptor(desc_attns[req]); infiniopDestroyAttentionDescriptor(desc_attns[req]);
} }
infiniopDestroyGemmDescriptor(desc_ffn_gate_up);
infiniopDestroySwiGLUDescriptor(desc_swiglu);
infiniopDestroyGemmDescriptor(desc_ffn_down);
infiniopDestroyRMSNormDescriptor(desc_norm_out); infiniopDestroyRMSNormDescriptor(desc_norm_out);
infiniopDestroyGemmDescriptor(desc_out_embd); infiniopDestroyGemmDescriptor(desc_out_embd);
infiniopDestroyRandomSampleDescriptor(desc_sample); infiniopDestroyRandomSampleDescriptor(desc_sample);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment