Unverified Commit fbb4754c authored by Liangsheng Yin, committed by GitHub

Fix vocab mask update bug (#1376)

parent 6c7cb903
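
As the diff below suggests, the regex vocab mask is now rebuilt when InputMetadata is created from the already filtered/merged batch, instead of earlier in ScheduleBatch, so its batch dimension cannot go stale. The following is a minimal, purely illustrative sketch of the kind of shape mismatch a stale mask causes; the vocab size and batch sizes are made up and not taken from the commit:

import torch

# Illustrative only: a mask built for a 4-request batch no longer lines up
# once the batch has been filtered down to 3 requests.
vocab_size = 8
vocab_mask = torch.zeros((4, vocab_size), dtype=torch.bool)  # built before filtering
logits = torch.randn((3, vocab_size))                        # batch after filtering

# logits.masked_fill_(vocab_mask, float("-inf"))  # would fail: (4, 8) mask vs (3, 8) logits

# Rebuilding the mask against the current batch restores agreement.
vocab_mask = torch.zeros((3, vocab_size), dtype=torch.bool)
logits = logits.masked_fill(vocab_mask, float("-inf"))
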
@@ -652,8 +652,6 @@ class ScheduleBatch:
self.req_pool_indices, self.seq_lens - 1
] = self.out_cache_loc
self.sampling_info.update_regex_vocab_mask(self)
def filter_batch(self, unfinished_indices: List[int]):
if unfinished_indices is None or len(unfinished_indices) == 0:
# Filter out all requests
@@ -195,7 +195,8 @@ class InputMetadata:
top_logprobs_nums=batch.top_logprobs_nums,
)
ret.sampling_info.prepare_penalties()
ret.sampling_info.update_penalties()
ret.sampling_info.update_regex_vocab_mask(batch)
ret.compute_positions(batch)
@@ -34,6 +34,9 @@ class SamplingBatchInfo:
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def __len__(self):
return len(self.temperatures)
def can_run_in_cuda_graph(self):
# Vocab bias and min_ps are not supported in CUDA graph
return (
@@ -118,11 +121,9 @@ class SamplingBatchInfo:
# Handle logit bias but only allocate when needed
ret.logit_bias = None
ret.update_regex_vocab_mask(batch)
return ret
def prepare_penalties(self):
def update_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
@@ -174,6 +175,26 @@ class SamplingBatchInfo:
if self_val is not None: # logit_bias can be None
setattr(self, item, self_val[new_indices])
    @staticmethod
    def merge_bias_tensor(
        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
    ):
        # Either bias tensor can be None. If both are None there is nothing to
        # merge; otherwise the missing side is materialized with the default
        # value so the two batches can be concatenated.
        if lhs is not None or rhs is not None:
            if lhs is not None:
                shape, dtype = lhs.shape[1:], lhs.dtype
            else:
                shape, dtype = rhs.shape[1:], rhs.dtype
            if lhs is None:
                lhs = torch.empty((bs1, *shape), dtype=dtype, device="cuda").fill_(
                    default
                )
            if rhs is None:
                rhs = torch.empty((bs2, *shape), dtype=dtype, device="cuda").fill_(
                    default
                )
            return torch.cat([lhs, rhs])
        return None

def merge(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
@@ -187,19 +208,6 @@ class SamplingBatchInfo:
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
# logit_bias can be None
if self.logit_bias is not None or other.logit_bias is not None:
vocab_size = (
self.logit_bias.shape[1]
if self.logit_bias is not None
else other.logit_bias.shape[1]
)
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
if other.logit_bias is None:
other.logit_bias = torch.zeros(
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other)
)
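
For context, here is a hypothetical use of the new merge_bias_tensor helper; the shapes and values, as well as the availability of SamplingBatchInfo and a CUDA device, are assumptions for illustration and not part of the commit. When only one side of the merge carries a logit bias, the missing side is filled with the default value so the two batches can be concatenated.

import torch

# Hypothetical usage sketch; assumes SamplingBatchInfo has been imported
# (its module path is not shown in this diff) and a CUDA device is available.
vocab_size = 16
lhs_bias = torch.randn((2, vocab_size), device="cuda")  # bias for a 2-request batch
rhs_bias = None                                          # the other batch has no bias

merged = SamplingBatchInfo.merge_bias_tensor(lhs_bias, rhs_bias, bs1=2, bs2=3)
print(merged.shape)  # torch.Size([5, 16]); the last 3 rows are filled with the default, 0
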