Unverified Commit 4eb8dd83 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

Fix init of batch state (#682)

* fix init of finished buf

* fix `finished_count`
parent b7c88ca8
......@@ -355,7 +355,9 @@ bool LlamaBatch<T>::Initialize()
});
// all blocks are not enough to hold a single sequence
if (!sequences.empty()) {
FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks.");
}
// move swap-ins to the back
auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) {
......@@ -398,6 +400,8 @@ bool LlamaBatch<T>::Initialize()
ClearState(*incoming_);
}
FT_CHECK(state_->size <= max_batch_size_);
/// Update block ptrs when there were
// 1. swap-in or swap-out
// 2. holes in the active buffer
......@@ -810,9 +814,6 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
// for
for (int i = 0; i < batch_size; ++i) {
h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
if (max_context_len >= h_seq_limit_len_[i]) { // mask finished sequences
state_->h_finished[i] = true;
}
}
Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
Copy(state_->h_finished, batch_size, finished_buf_);
......@@ -1402,6 +1403,8 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
shared_state->barrier->wait();
auto modified = Initialize();
// finished sequences is handled by `Initialize()`
finished_count = 0;
ContextDecode();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment