Unverified commit a7c5007c, authored by Li Zhang, committed by GitHub
Browse files

[Fix] Skip empty batch (#747)

parent d3386351
...@@ -475,6 +475,10 @@ bool LlamaBatch<T>::Initialize() ...@@ -475,6 +475,10 @@ bool LlamaBatch<T>::Initialize()
template<typename T> template<typename T>
void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc) void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc)
{ {
if (desc.empty()) {
return;
}
std::vector<int> idxs(desc.size()); std::vector<int> idxs(desc.size());
std::iota(idxs.begin(), idxs.end(), 0); std::iota(idxs.begin(), idxs.end(), 0);
...@@ -1430,18 +1434,21 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id) ...@@ -1430,18 +1434,21 @@ void LlamaBatch<T>::InternalThreadEntry(int device_id)
// finished sequences is handled by `Initialize()` // finished sequences is handled by `Initialize()`
finished_count = 0; finished_count = 0;
ContextDecode();
if (state_->active_size) { if (state_->active_size) {
ContextDecode();
if (modified) { if (modified) {
g = InitializeGeneration(); g = InitializeGeneration();
InitializeSampling(); InitializeSampling();
} }
for (int i = 0; i < step_length_; ++i) { for (int i = 0; i < step_length_; ++i) {
if (!Generate(g)) { if (!Generate(g)) {
break; break;
} }
} }
if (auto signals = Finish(g, finished_count); !signals.empty()) { if (auto signals = Finish(g, finished_count); !signals.empty()) {
if (finished_count) { if (finished_count) {
// Finished requests and corresponding output tensors will be released when notified // Finished requests and corresponding output tensors will be released when notified
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment