"src/array/cpu/negative_sampling.cc" did not exist on "01bec4a31d1a3135af33ac69cefaf7ceecbaf7b3"
Unverified commit 317631ca, authored by Lianmin Zheng and committed by GitHub

[Fix] Move ScheduleBatch out of SamplingInfo (#1556)

parent b5648353
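
For orientation, a minimal sketch of the coupling this commit establishes, assuming a Req object that carries regex_fsm and regex_fsm_state attributes (only the regex-related pieces from the diff are shown; everything else is elided and this is not the repository's actual code): ScheduleBatch now owns a has_regex flag and copies the per-request FSMs and states into SamplingBatchInfo, instead of SamplingBatchInfo reaching back into the batch.

# Sketch only: the regex-related fields from this diff, with all other fields elided.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SamplingBatchInfo:
    vocab_size: int
    # After this commit the FSMs and their states are stored here directly,
    # instead of being read through a back-reference to the ScheduleBatch.
    regex_fsms: Optional[List[Any]] = None
    regex_fsm_states: Optional[List[int]] = None


@dataclass
class ScheduleBatch:
    reqs: List[Any]              # assumed: each req has .regex_fsm / .regex_fsm_state
    sampling_info: SamplingBatchInfo
    has_regex: bool = False      # new flag introduced by this commit

    def get_model_worker_batch(self):
        # Mirrors the `if self.has_regex:` guard in the diff below: the lists are
        # only materialized when at least one request uses a regex FSM.
        if self.has_regex:
            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
            self.sampling_info.regex_fsm_states = [
                req.regex_fsm_state for req in self.reqs
            ]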
@@ -423,10 +423,14 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -435,6 +439,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )
 
     def batch_size(self):
@@ -750,7 +755,9 @@ class ScheduleBatch:
             ]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)
 
         self.sampling_info.filter_batch(unfinished_indices, new_indices)
@@ -771,9 +778,11 @@ class ScheduleBatch:
             self.top_logprobs_nums.extend([0] * len(other.reqs))
         elif other.return_logprob:
             self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
-        self.has_stream = any(req.stream for req in self.reqs)
+
         self.reqs.extend(other.reqs)
         self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -787,7 +796,11 @@ class ScheduleBatch:
         image_inputs = [r.image_inputs for r in self.reqs]
         lora_paths = [req.lora_path for req in self.reqs]
 
-        self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
 
         return ModelWorkerBatch(
             forward_mode=self.forward_mode,
...
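
The second file, shown next, is where the schedule_batch back-reference is removed from SamplingBatchInfo. As a hedged illustration of what the rewritten update_regex_vocab_mask amounts to (FSMStub below is a stand-in; the only interface assumed is get_next_instruction(state).tokens, the one call the diff actually makes), the mask row for a request with an FSM starts fully blocked and only the FSM's allowed tokens are cleared:

# Illustration only; FSMStub stands in for the real compiled regex FSM.
import torch


class FSMStub:
    """Maps an FSM state to the token ids that may be emitted from that state."""

    def __init__(self, allowed_by_state):
        self.allowed_by_state = allowed_by_state

    def get_next_instruction(self, state):
        # Return an object with a .tokens attribute, matching the call in the diff.
        return type("Instruction", (), {"tokens": self.allowed_by_state[state]})()


def build_vocab_mask(regex_fsms, regex_fsm_states, vocab_size, device="cpu"):
    # True means "disallowed", matching the fill_(1) / "= 0" pattern in the diff.
    mask = torch.zeros(len(regex_fsms), vocab_size, dtype=torch.bool, device=device)
    for i, regex_fsm in enumerate(regex_fsms):
        if regex_fsm is not None:
            mask[i].fill_(1)
            mask[i][regex_fsm.get_next_instruction(regex_fsm_states[i]).tokens] = 0
    return mask


fsm = FSMStub({0: [2, 5, 7]})
mask = build_vocab_mask([fsm, None], [0, 0], vocab_size=10)
# Row 0 allows only tokens 2, 5 and 7; row 1 has no FSM, so every token stays allowed.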
@@ -84,10 +84,6 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        # This is only for regex_fsm. We notice a regression if we maintain the list of regex_fsm
-        # in SamplingBatchInfo, so we keep it here.
-        ret.schedule_batch = batch
-
         return ret
 
     def __len__(self):
@@ -113,7 +109,7 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-        has_regex = any(req.regex_fsm is not None for req in self.schedule_batch.reqs)
+        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
 
         # Reset the vocab mask
         self.vocab_mask = None
@@ -122,11 +118,11 @@ class SamplingBatchInfo:
             self.vocab_mask = torch.zeros(
                 len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(self.schedule_batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
...
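
For completeness, a hedged sketch of how a vocab mask with this convention is typically consumed before sampling; this step is not part of the diff and is only an assumption about the surrounding code: True means "disallowed", so masked logits are pushed to negative infinity and receive zero probability after the softmax.

# Assumed usage, not shown in this commit.
import torch

vocab_size = 10
logits = torch.randn(1, vocab_size)

vocab_mask = torch.ones(1, vocab_size, dtype=torch.bool)  # start with every token blocked
vocab_mask[0, [2, 5, 7]] = False                          # tokens the FSM allows

logits = logits.masked_fill(vocab_mask, float("-inf"))
probs = torch.softmax(logits, dim=-1)  # probability mass lands only on tokens 2, 5 and 7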