Commit c3d5efa5 authored by PanZezhong's avatar PanZezhong
Browse files

support qkv bias

parent 5540d53a
......@@ -77,10 +77,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto dh = meta.dh;
auto d = meta.d;
auto dt_logits = meta.dt_logits;
// std::cout << "dt_logits: " <<(int)dt_logits << std::endl;
auto di = meta.di / ndev;
auto dvoc = meta.dvoc;
auto stream = rsrc.stream;
bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
// Allocate buffers
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, stream);
......@@ -128,6 +128,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
workspace_size = std::max(workspace_size, temp_size);
// Attention
infiniopGemmDescriptor_t desc_attn_qkv, desc_attn_o;
infiniopRearrangeDescriptor_t desc_qkv_bias;
if (has_qkv_bias) {
RUN_INFINI(infiniopCreateRearrangeDescriptor(
rsrc.handle, &desc_qkv_bias, qkv_buf->desc()->get(),
TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->get()));
}
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_qkv, qkv_buf->desc()->get(),
logits_in->desc()->get(), rsrc.w_attn_qkv[0]->desc()->get()));
......@@ -224,7 +230,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
workspace_size = std::max(workspace_size, temp_size);
// Allocate workspace
workspace = rsrc.workspace_allocator->alloc(workspace_size);
// Compute
for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention
......@@ -234,6 +240,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
logits_out->data(), logits_in->data(),
rsrc.w_attn_norm[layer]->data(), stream));
// qkv_proj
if (has_qkv_bias) {
RUN_INFINI(infiniopRearrange(
desc_qkv_bias,
qkv_buf->data(), rsrc.b_attn_qkv.data(), stream));
}
RUN_INFINI(infiniopGemm(
desc_attn_qkv, workspace, workspace_size,
qkv_buf->data(), logits_out->data(),
......@@ -347,6 +358,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// Clean up
infiniopDestroyRMSNormDescriptor(desc_norm);
infiniopDestroyRearrangeDescriptor(desc_qkv_bias);
infiniopDestroyGemmDescriptor(desc_attn_qkv);
infiniopDestroyGemmDescriptor(desc_attn_o);
infiniopDestroyRoPEDescriptor(desc_rope_q);
......@@ -354,6 +366,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
for (uint32_t req = 0; req < nreq; req++) {
infiniopDestroyAttentionDescriptor(desc_attns[req]);
}
infiniopDestroyGemmDescriptor(desc_ffn_gate_up);
infiniopDestroySwiGLUDescriptor(desc_swiglu);
infiniopDestroyGemmDescriptor(desc_ffn_down);
infiniopDestroyRMSNormDescriptor(desc_norm_out);
infiniopDestroyGemmDescriptor(desc_out_embd);
infiniopDestroyRandomSampleDescriptor(desc_sample);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment