Commit 2bc4be4e (unverified), authored by Woosuk Kwon, committed by GitHub
Browse files

[V1][Minor] Simplify rejection sampler's parse_output (#15741)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent c67abd61
......@@ -107,7 +107,6 @@ class RejectionSampler(nn.Module):
@staticmethod
def parse_output(
output_token_ids: torch.Tensor,
ignored_req_idxs: list[int],
vocab_size: int,
) -> list[list[int]]:
"""Parse the output of the rejection sampler.
......@@ -117,9 +116,6 @@ class RejectionSampler(nn.Module):
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
ignored_req_idxs: The indices of the requests that should not be
sampled. This is usually because the request is still in the
prefill phase.
vocab_size: The size of the vocabulary.
Returns:
......@@ -129,11 +125,8 @@ class RejectionSampler(nn.Module):
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
ignored_req_idx_set = set(ignored_req_idxs)
outputs = [
row[valid_mask[i]].tolist()
if i not in ignored_req_idx_set else []
for i, row in enumerate(output_token_ids_np)
]
return outputs
......
......@@ -1121,16 +1121,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
discard_sampled_tokens_req_indices,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.use_spec_decode:
spec_token_ids = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment