Commit 424591d5 (unverified), authored by Ke Bao, committed by GitHub
Browse files

Fix spec filter batch when target extend (#10991)

parent d1676cd4
......@@ -1736,7 +1736,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info:
-self.spec_info.filter_batch(keep_indices_device)
+if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
+has_been_filtered = False
+else:
+has_been_filtered = True
+self.spec_info.filter_batch(
+new_indices=keep_indices_device,
+has_been_filtered=has_been_filtered,
+)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
......
......@@ -405,7 +405,7 @@ class NgramVerifyInput:
return logits_output, self.verified_id, self.accept_length.sum().item()
-def filter_batch(self, new_indices: torch.Tensor):
+def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
pass
def merge_batch(self, spec_info: NgramVerifyInput):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment