Unverified commit 9752da9d, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Minor simplification for BadWordsState (#34669)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent 04925b22
......@@ -39,18 +39,11 @@ class BadWordsState:
)
# number of bad words per request
self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# whether request uses bad words
self.use_bad_words = np.zeros(self.max_num_reqs, dtype=bool)
def add_request(
self,
req_idx: int,
sampling_params: SamplingParams,
) -> None:
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
bad_words_token_ids = sampling_params.bad_words_token_ids
if not bad_words_token_ids:
self.num_bad_words.np[req_idx] = 0
self.use_bad_words[req_idx] = False
return
num_bad_words = len(bad_words_token_ids)
......@@ -77,7 +70,6 @@ class BadWordsState:
self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
self.bad_word_offsets.stage_write(req_idx, 0, offsets)
self.num_bad_words.np[req_idx] = num_bad_words
self.use_bad_words[req_idx] = True
def apply_staged_writes(self) -> None:
self.num_bad_words.copy_to_uva()
......@@ -92,11 +84,11 @@ class BadWordsState:
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> None:
if not np.any(self.use_bad_words[idx_mapping_np]):
max_num_bad_words = int(self.num_bad_words.np[idx_mapping_np].max())
if max_num_bad_words == 0:
# No request uses bad words. Skip the kernel launch.
return
actual_max_num_bad_words = int(np.max(self.num_bad_words.np[idx_mapping_np]))
apply_bad_words(
logits,
idx_mapping,
......@@ -108,7 +100,7 @@ class BadWordsState:
self.total_len,
input_ids,
expanded_local_pos,
actual_max_num_bad_words,
max_num_bad_words,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment