Commit 0d2a151e (unverified), authored by akhoroshev, committed by GitHub
Browse files

[bug] fix mismatched shape for decoder output tensor (#517)

parent 169d088a
......@@ -256,7 +256,7 @@ void LlamaV2<T>::contextDecode(T* deocder_output,
};
std::unordered_map<std::string, Tensor> decoder_output_tensors{
-{"decoder_output", {MEMORY_GPU, dtype, {bsz, max_input_len, hidden_units_}, context_decoder_output_buf}},
+{"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, context_decoder_output_buf}},
{"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_cache_ptr}},
{"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_cache_ptr}},
{"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, deocder_output}}};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment