Unverified Commit 434961c6 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

Fix cache/output length calculation (#738)

parent 6b00f623
......@@ -207,7 +207,6 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
auto& seq = *state.sequences[idx];
if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
/// TODO: revise step setting
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
......@@ -1258,7 +1257,17 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
check_cuda_error(cudaStreamSynchronize(stream_));
// invariant: context_length = sequence_length + 1
// `SequenceManager` needs real-time value of cache length
// ! Must be done before incrementing `h_context_length` because the generated token is NOT kv-cached yet
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i];
}
}
// invariant: context_length = sequence_length + 1, so that h_context_length include all (including the one just
// generated) tokens
for (int i = 0; i < batch_size; ++i) {
++state_->h_context_length[i];
}
......@@ -1267,7 +1276,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
int* output_ptr = h_output_ids_;
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
const int count = state_->h_context_length[i];
// TODO: sync history output tokens at when receiving the request and copy only the last token here
std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
*h_request_seqlen_ptrs_[i] = count;
......@@ -1284,14 +1293,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vect
TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
}
// `SequenceManager` needs real-time value of cache length
for (int i = 0; i < batch_size; ++i) {
if (state_->requests[i]) {
FT_CHECK(state_->sequences[i]);
state_->sequences[i]->cache_len = state_->h_context_length[i];
}
}
std::vector<Signal> signals;
{
NvtxScope _("stream_and_completion_signal");
......@@ -1343,8 +1344,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
}
else {
// Account for the last generated token if not a stop request (which doesn't generate)
const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
const int output_len = state_->h_context_length[index];
auto& seq = *state_->sequences[index];
// Update token IDs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment