Unverified Commit 434961c6 authored by Li Zhang, committed by GitHub

Fix cache/output length calculation (#738)

parent 6b00f623
@@ -207,7 +207,6 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
     auto& seq = *state.sequences[idx];
     if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
-        /// TODO: revise step setting
         if (step <= seq.tokens.size()) {
            seq.tokens.resize(step);
            seq.cache_len = std::min(seq.cache_len, step);
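For context, a minimal standalone sketch of the rewind logic this hunk touches (the Sequence struct below is a hypothetical stand-in for TurboMind's real sequence type, reduced to the two fields used here): rolling back to step drops the tokens past that point and clamps cache_len so stale KV entries are not reused.

#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical, reduced stand-in for the real sequence state.
struct Sequence {
    std::vector<int> tokens;  // token IDs accumulated so far
    int              cache_len = 0;  // how many of them have KV-cache entries
};

// Mirrors the clamping above: a valid `step` truncates the token history
// and caps cache_len so no stale KV block is reused.
void rewind(Sequence& seq, int step)
{
    if (step >= 0 && step <= static_cast<int>(seq.tokens.size())) {
        seq.tokens.resize(step);
        seq.cache_len = std::min(seq.cache_len, step);
    }
}

int main()
{
    Sequence seq{{1, 2, 3, 4, 5}, 5};
    rewind(seq, 3);
    assert(seq.tokens.size() == 3 && seq.cache_len == 3);
}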
@@ -1258,7 +1257,17 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
     check_cuda_error(cudaStreamSynchronize(stream_));

-    // invariant: context_length = sequence_length + 1
+    // `SequenceManager` needs real-time value of cache length
+    // ! Must be done before incrementing `h_context_length` because the generated token is NOT kv-cached yet
+    for (int i = 0; i < batch_size; ++i) {
+        if (state_->requests[i]) {
+            FT_CHECK(state_->sequences[i]);
+            state_->sequences[i]->cache_len = state_->h_context_length[i];
+        }
+    }
+
+    // invariant: context_length = sequence_length + 1, so that h_context_length includes all tokens (including
+    // the one just generated)
     for (int i = 0; i < batch_size; ++i) {
         ++state_->h_context_length[i];
     }
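The essential point of this hunk is ordering: cache_len must be snapshotted before h_context_length is incremented, because the token sampled this step has no KV-cache entry yet. A minimal sketch of the invariant, with a hypothetical Slot struct standing in for the real batch state:

#include <cassert>
#include <vector>

// Hypothetical per-sequence slot; the real state lives in BatchState arrays.
struct Slot {
    int context_length;  // tokens known to the batch before this increment
    int cache_len;       // tokens whose KV entries actually exist
};

void finish_step(std::vector<Slot>& slots)
{
    // 1) Snapshot first: the freshly sampled token is NOT kv-cached yet,
    //    so the cache holds exactly context_length tokens.
    for (auto& s : slots) {
        s.cache_len = s.context_length;
    }
    // 2) Only then extend context_length to cover the new token.
    for (auto& s : slots) {
        ++s.context_length;
        assert(s.cache_len == s.context_length - 1);
    }
}

int main()
{
    std::vector<Slot> slots{{10, 0}, {37, 0}};
    finish_step(slots);
    assert(slots[0].cache_len == 10 && slots[0].context_length == 11);
}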
@@ -1267,7 +1276,7 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
     int* output_ptr = h_output_ids_;
     for (int i = 0; i < batch_size; ++i) {
         if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
-            const int count = state_->h_context_length[i] - 1 + int(g.step != g.max_init_ctx_len);
+            const int count = state_->h_context_length[i];
             // TODO: sync history output tokens when receiving the request and copy only the last token here
             std::copy(output_ptr, output_ptr + count, h_request_output_ids_ptrs_[i]);
             *h_request_seqlen_ptrs_[i] = count;
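With the increment above already applied, h_context_length covers the just-generated token, so the output length is simply that value and the old "- 1 + int(g.step != g.max_init_ctx_len)" correction disappears. A standalone sketch of the simplified copy-out (buffer names echo the diff, contents are invented):

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // One sequence's tokens in the flat output buffer (values invented).
    std::vector<int> h_output_ids{101, 7, 42, 9, 55};
    const int        h_context_length = 5;  // already includes the new token

    // After #738 the count needs no per-step correction term.
    const int        count = h_context_length;
    std::vector<int> request_output(count);
    std::copy(h_output_ids.begin(), h_output_ids.begin() + count, request_output.begin());

    std::printf("copied %d tokens\n", count);
}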
@@ -1284,14 +1293,6 @@ auto LlamaBatch<T>::Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>
         TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
     }

-    // `SequenceManager` needs real-time value of cache length
-    for (int i = 0; i < batch_size; ++i) {
-        if (state_->requests[i]) {
-            FT_CHECK(state_->sequences[i]);
-            state_->sequences[i]->cache_len = state_->h_context_length[i];
-        }
-    }
-
     std::vector<Signal> signals;
     {
         NvtxScope _("stream_and_completion_signal");
@@ -1343,8 +1344,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Signal
         FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
     }
     else {
-        // Account for the last generated token if not a stop request (which doesn't generate)
-        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(force_stop);
+        const int output_len = state_->h_context_length[index];
         auto& seq = *state_->sequences[index];
         // Update token IDs
...
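Why Interrupt() can drop the "+ 1 - force_stop" compensation: under the new invariant, h_context_length already counts every token (including the last generated one) by the time Interrupt runs, so the output length is read off directly. A hedged before/after sketch; the "before" timing is inferred from the deleted comment, and the values are invented:

#include <cstdio>

// Before #738: h_context_length was assumed to lag one behind after a
// generation step, except for stop requests (which generate nothing).
int output_len_before(int h_context_length, bool force_stop)
{
    return h_context_length + 1 - static_cast<int>(force_stop);
}

// After #738: the invariant guarantees the count is already complete.
int output_len_after(int h_context_length)
{
    return h_context_length;
}

int main()
{
    std::printf("before: %d  after: %d\n", output_len_before(11, false), output_len_after(12));
}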