Unverified commit e753b372, authored by PanZezhong1725 and committed by GitHub
Browse files

Merge pull request #30 from InfiniTensor/issue/29

issue/29 - fixing tensor length mismatches across requests
parents c41055b5 c3f8b79b
...@@ -217,17 +217,17 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, ...@@ -217,17 +217,17 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k);
rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v);
// qk // qk
rearrange(q_rearrange, q); rearrange(q_rearrange->slice(2, 0, seq_len), q);
auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len}); auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len});
auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0});
linear(qk_gemm, rearrange_q_buf, k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
// softmax // softmax
auto qk_softmax = qk_buf->view({nh, seq_len, total_len}); auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax, qk_softmax); causalSoftmax(qk_softmax, qk_softmax);
auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
linear(attn_val_buf, qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
// rearrange attn val // rearrange attn val
rearrange(o, attn_val_gemm); rearrange(o, attn_val_gemm->slice(2, 0, seq_len));
token_offset += seq_len; token_offset += seq_len;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment