Unverified commit b3f2fddd, authored by Nicolò Lucchesi, committed by GitHub
Browse files

[TPU][V1] Fix exponential padding when `max-num-batched-tokens` is not a power of 2 (#16596)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent aa29841e
...@@ -299,6 +299,18 @@ def test_get_paddings(): ...@@ -299,6 +299,18 @@ def test_get_paddings():
actual_paddings = _get_token_paddings(min_token_size, max_token_size, actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap) padding_gap)
assert actual_paddings == expected_paddings assert actual_paddings == expected_paddings
# Exponential padding.
max_token_size, padding_gap = 1024, 0
expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 256, 512]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
padding_gap)
assert actual_paddings == expected_paddings
def test_get_padded_token_len(): def test_get_padded_token_len():
......
...@@ -1040,9 +1040,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, ...@@ -1040,9 +1040,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
if padding_gap == 0: if padding_gap == 0:
logger.info("Using exponential token paddings:") logger.info("Using exponential token paddings:")
while num <= max_token_size: while True:
logger.info(" %d", num) logger.info(" %d", num)
paddings.append(num) paddings.append(num)
if num >= max_token_size:
break
num *= 2 num *= 2
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment