Unverified commit 4eb8dd83, authored by Li Zhang, committed by GitHub
Browse files

Fix init of batch state (#682)

* fix init of finished buf

* fix `finished_count`
parent b7c88ca8
...@@ -355,7 +355,9 @@ bool LlamaBatch<T>::Initialize() ...@@ -355,7 +355,9 @@ bool LlamaBatch<T>::Initialize()
}); });
// all blocks are not enough to hold a single sequence // all blocks are not enough to hold a single sequence
FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks."); if (!sequences.empty()) {
FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks.");
}
// move swap-ins to the back // move swap-ins to the back
auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) { auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) {
...@@ -398,6 +400,8 @@ bool LlamaBatch<T>::Initialize() ...@@ -398,6 +400,8 @@ bool LlamaBatch<T>::Initialize()
ClearState(*incoming_); ClearState(*incoming_);
} }
FT_CHECK(state_->size <= max_batch_size_);
/// Update block ptrs when there were /// Update block ptrs when there were
// 1. swap-in or swap-out // 1. swap-in or swap-out
// 2. holes in the active buffer // 2. holes in the active buffer
...@@ -810,9 +814,6 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState ...@@ -810,9 +814,6 @@ auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
// for // for
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]); h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
if (max_context_len >= h_seq_limit_len_[i]) { // mask finished sequences
state_->h_finished[i] = true;
}
} }
Copy(h_seq_limit_len_, batch_size, seq_limit_len_); Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
Copy(state_->h_finished, batch_size, finished_buf_); Copy(state_->h_finished, batch_size, finished_buf_);
...@@ -1402,6 +1403,8 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id) ...@@ -1402,6 +1403,8 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
shared_state->barrier->wait(); shared_state->barrier->wait();
auto modified = Initialize(); auto modified = Initialize();
// finished sequences is handled by `Initialize()`
finished_count = 0;
ContextDecode(); ContextDecode();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment