Unverified commit fbb4754c, authored by Liangsheng Yin and committed by GitHub

Fix vocab mask update bug (#1376)

parent 6c7cb903
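
This change moves the `update_regex_vocab_mask` call out of `ScheduleBatch` and `SamplingBatchInfo.from_schedule_batch` and into `InputMetadata` preparation, so the regex vocab mask is refreshed together with the penalty state right before sampling. For context, below is a minimal sketch of what applying a per-request vocab mask to logits before sampling can look like; the helper names (`build_vocab_mask`, `apply_vocab_mask`) and the boolean-mask layout are assumptions for illustration, not sglang's actual implementation.

```python
import torch

# Illustrative sketch only: a vocabulary mask that disallows every token
# except an allowed set, applied to logits before sampling. The names and
# layout here are assumptions, not the sglang API.
def build_vocab_mask(allowed_token_ids, vocab_size, device="cpu"):
    # True marks tokens that must be masked out (disallowed).
    mask = torch.ones(vocab_size, dtype=torch.bool, device=device)
    mask[allowed_token_ids] = False
    return mask

def apply_vocab_mask(logits, vocab_mask):
    # Disallowed tokens get -inf so they can never be sampled.
    return logits.masked_fill(vocab_mask, float("-inf"))

if __name__ == "__main__":
    logits = torch.randn(1, 8)
    mask = build_vocab_mask(allowed_token_ids=[2, 5], vocab_size=8)
    masked = apply_vocab_mask(logits, mask)
    print(masked)  # only positions 2 and 5 remain finite
```
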
@@ -652,8 +652,6 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

-        self.sampling_info.update_regex_vocab_mask(self)
-
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -195,7 +195,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

-        ret.sampling_info.prepare_penalties()
+        ret.sampling_info.update_penalties()
+        ret.sampling_info.update_regex_vocab_mask(batch)

         ret.compute_positions(batch)
@@ -34,6 +34,9 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None

+    def __len__(self):
+        return len(self.temperatures)
+
     def can_run_in_cuda_graph(self):
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
@@ -118,11 +121,9 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None

-        ret.update_regex_vocab_mask(batch)
-
         return ret

-    def prepare_penalties(self):
+    def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
@@ -174,6 +175,26 @@ class SamplingBatchInfo:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])

+    @staticmethod
+    def merge_bias_tensor(
+        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+    ):
+        # Bias tensors can be None; if both are None there is nothing to merge.
+        if lhs is not None or rhs is not None:
+            shape, dtype = None, None
+            if lhs is not None:
+                shape, dtype = lhs.shape[1:], lhs.dtype
+            else:
+                shape, dtype = rhs.shape[1:], rhs.dtype
+            # Materialize the missing side with the default fill value so the
+            # two batches can be concatenated along the batch dimension.
+            if lhs is None:
+                lhs = torch.empty((bs1, *shape), dtype=dtype, device="cuda").fill_(default)
+            if rhs is None:
+                rhs = torch.empty((bs2, *shape), dtype=dtype, device="cuda").fill_(default)
+            return torch.cat([lhs, rhs])
+        return None
+
     def merge(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
@@ -187,19 +208,6 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))

-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other)
+        )
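
For reference, here is a standalone sketch of how the new `merge_bias_tensor` helper behaves when one side has no bias tensor: the missing side is filled with the default value so the two halves can be concatenated. This version mirrors the logic in the diff above but runs on CPU for illustration; it is not the exact sglang code.

```python
import torch

def merge_bias_tensor(lhs, rhs, bs1, bs2, default=0.0, device="cpu"):
    # Either side may be None; if both are None there is nothing to merge.
    if lhs is None and rhs is None:
        return None
    # Take the per-row shape and dtype from whichever side exists.
    ref = lhs if lhs is not None else rhs
    shape, dtype = ref.shape[1:], ref.dtype
    if lhs is None:
        lhs = torch.full((bs1, *shape), default, dtype=dtype, device=device)
    if rhs is None:
        rhs = torch.full((bs2, *shape), default, dtype=dtype, device=device)
    return torch.cat([lhs, rhs])

if __name__ == "__main__":
    # Batch A (2 requests) carries a logit bias, batch B (3 requests) has none.
    bias_a = torch.tensor([[0.0, 1.0, 0.0, -1.0],
                           [0.5, 0.0, 0.0, 0.0]])
    merged = merge_bias_tensor(bias_a, None, bs1=2, bs2=3)
    print(merged.shape)  # torch.Size([5, 4]); the last 3 rows are all zeros
```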