Commit 61ae257c authored by wooway777's avatar wooway777
Browse files

issue/37 - fixed an inappropriate qk buffer slicing

parent e753b372
......@@ -181,7 +181,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
max_seq_len = std::max(max_seq_len, size_t(seq_len));
}
auto qk_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto qk_buf = Tensor::buffer(dt_logits, {nh * max_qk_size}, rsrc.memory_pool);
auto rearrange_q_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
auto q_rearrange = rearrange_q_buf->view({nkvh, ngroup, max_seq_len, dh});
auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool);
......@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
rearrange(q_rearrange->slice(2, 0, seq_len), q);
auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len});
auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
// softmax
auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len});
auto qk_softmax = qk_gemm->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment