"tutorials/vscode:/vscode.git/clone" did not exist on "53b9a4bdbc36e5253adfbb780dacccffa66c4fb7"
Unverified Commit 47f20da2 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix regex mask (#1296)

parent 4a9f8ea4
...@@ -63,7 +63,7 @@ class Sampler(CustomOp): ...@@ -63,7 +63,7 @@ class Sampler(CustomOp):
logits.add_(sampling_info.logit_bias) logits.add_(sampling_info.logit_bias)
if sampling_info.vocab_mask is not None: if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
logits = self._apply_penalties(logits, sampling_info) logits = self._apply_penalties(logits, sampling_info)
......
...@@ -154,15 +154,15 @@ class SamplingBatchInfo: ...@@ -154,15 +154,15 @@ class SamplingBatchInfo:
self.vocab_mask = None self.vocab_mask = None
if has_regex: if has_regex:
self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device
)
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
if req.regex_fsm is not None: if req.regex_fsm is not None:
if self.vocab_mask is None: self.vocab_mask[i].fill_(1)
self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device
)
self.vocab_mask[i][ self.vocab_mask[i][
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
] = 1 ] = 0
def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor): def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices) self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment