Commit bd7599d3 authored by Nicolò Lucchesi, committed by GitHub

[V1][TPU] Do not compile sampling more than needed (#15883)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent 01b61136
@@ -862,7 +862,9 @@ class TPUModelRunner:
 out = self.model.sample_from_hidden(dummy_hidden,
                                     sampling_meta)
 out = out.cpu()
-if num_reqs_to_sample >= self.max_num_reqs:
+# Requests can't be more than tokens. But do compile for the
+# next bigger value in case num_tokens uses bucketed padding.
+if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
     break
 # Make sure to compile the `max_num_reqs` upper-limit case
 num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
......
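
Reviewer note: the effect of the new stopping bound can be sketched outside the runner. The snippet below is a simplified simulation, not the vLLM code. `padded_num_reqs` is a hypothetical stand-in for `_get_padded_num_reqs_with_upper_limit` (assumed here to round up to a power of two with a floor of 8, capped at the upper limit), and the starting value and the `num_reqs + 1` step are likewise assumptions, since the hunk above is truncated before the call's arguments.

```python
def padded_num_reqs(x: int, upper_limit: int, min_reqs: int = 8) -> int:
    """Hypothetical padding rule: round x up to a power of two
    (at least min_reqs), then cap at upper_limit."""
    res = min_reqs if x <= min_reqs else 1 << (x - 1).bit_length()
    return min(res, upper_limit)


def compiled_sizes(num_tokens: int, max_num_reqs: int,
                   use_new_bound: bool) -> list[int]:
    """Return the num_reqs values a warm-up loop of this shape would
    compile sampling for, under the old or the new stopping bound."""
    sizes = []
    num_reqs = padded_num_reqs(1, max_num_reqs)  # assumed starting value
    while True:
        sizes.append(num_reqs)  # stands in for the sample_from_hidden call
        if use_new_bound:
            # New check: requests can't outnumber tokens.
            bound = min(num_tokens, max_num_reqs)
        else:
            # Old check: always walk up to max_num_reqs.
            bound = max_num_reqs
        if num_reqs >= bound:
            break
        # Assumed step: pad the next larger request count.
        num_reqs = padded_num_reqs(num_reqs + 1, max_num_reqs)
    return sizes


if __name__ == "__main__":
    print(compiled_sizes(16, 256, use_new_bound=False))
    print(compiled_sizes(16, 256, use_new_bound=True))
    print(compiled_sizes(20, 256, use_new_bound=True))
```

Under these assumptions, a 16-token bucket with `max_num_reqs=256` compiles six sampling shapes with the old check and two with the new one. With 20 tokens the loop still compiles 32, the next padded value above `num_tokens`, which is the case the added comment describes.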