Commit c3f8b79b authored by wooway777's avatar wooway777
Browse files

issue/29 - further fixed tensor dimensions on qk_gemm and qk_softmax

parent 9038f7ab
......@@ -218,11 +218,11 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk
rearrange(q_rearrange->slice(2, 0, seq_len), q);
auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len});
auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
// softmax
auto qk_softmax = qk_buf->view({nh, seq_len, total_len});
auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment