Unverified Commit 9752da9d authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Minor simplification for BadWordsState (#34669)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 04925b22
...@@ -39,18 +39,11 @@ class BadWordsState: ...@@ -39,18 +39,11 @@ class BadWordsState:
) )
# number of bad words per request # number of bad words per request
self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32) self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# whether request uses bad words
self.use_bad_words = np.zeros(self.max_num_reqs, dtype=bool)
def add_request( def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self,
req_idx: int,
sampling_params: SamplingParams,
) -> None:
bad_words_token_ids = sampling_params.bad_words_token_ids bad_words_token_ids = sampling_params.bad_words_token_ids
if not bad_words_token_ids: if not bad_words_token_ids:
self.num_bad_words.np[req_idx] = 0 self.num_bad_words.np[req_idx] = 0
self.use_bad_words[req_idx] = False
return return
num_bad_words = len(bad_words_token_ids) num_bad_words = len(bad_words_token_ids)
...@@ -77,7 +70,6 @@ class BadWordsState: ...@@ -77,7 +70,6 @@ class BadWordsState:
self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens) self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
self.bad_word_offsets.stage_write(req_idx, 0, offsets) self.bad_word_offsets.stage_write(req_idx, 0, offsets)
self.num_bad_words.np[req_idx] = num_bad_words self.num_bad_words.np[req_idx] = num_bad_words
self.use_bad_words[req_idx] = True
def apply_staged_writes(self) -> None: def apply_staged_writes(self) -> None:
self.num_bad_words.copy_to_uva() self.num_bad_words.copy_to_uva()
...@@ -92,11 +84,11 @@ class BadWordsState: ...@@ -92,11 +84,11 @@ class BadWordsState:
input_ids: torch.Tensor, input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor, expanded_local_pos: torch.Tensor,
) -> None: ) -> None:
if not np.any(self.use_bad_words[idx_mapping_np]): max_num_bad_words = int(self.num_bad_words.np[idx_mapping_np].max())
if max_num_bad_words == 0:
# No request uses bad words. Skip the kernel launch. # No request uses bad words. Skip the kernel launch.
return return
actual_max_num_bad_words = int(np.max(self.num_bad_words.np[idx_mapping_np]))
apply_bad_words( apply_bad_words(
logits, logits,
idx_mapping, idx_mapping,
...@@ -108,7 +100,7 @@ class BadWordsState: ...@@ -108,7 +100,7 @@ class BadWordsState:
self.total_len, self.total_len,
input_ids, input_ids,
expanded_local_pos, expanded_local_pos,
actual_max_num_bad_words, max_num_bad_words,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment