Unverified commit b2e85e26, authored by Alexander Matveev, committed by GitHub
Browse files

[V1] TPU - Revert to exponential padding by default (#15565)


Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent dd8a29da
...@@ -99,7 +99,7 @@ if TYPE_CHECKING: ...@@ -99,7 +99,7 @@ if TYPE_CHECKING:
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 64 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
def get_default_cache_root(): def get_default_cache_root():
...@@ -648,7 +648,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -648,7 +648,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# 8, we will run forward pass with [16, 24, 32, ...]. # 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP": "VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64, if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -944,18 +944,35 @@ def _get_paddings(min_token_size: int, max_token_size: int, ...@@ -944,18 +944,35 @@ def _get_paddings(min_token_size: int, max_token_size: int,
padding_gap: int) -> list[int]: padding_gap: int) -> list[int]:
"""Generate a list of padding size, starting from min_token_size, """Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size ending with a number that can cover max_token_size
If padding_gap == 0 then:
increase 2X each time (exponential)
else:
first increase the size to twice, first increase the size to twice,
then increase the padding size by padding_gap. then increase the padding size by padding_gap.
""" """
paddings = [] paddings = []
num = min_token_size num = min_token_size
if padding_gap == 0:
logger.info("Using exponential paddings:")
while num <= max_token_size:
logger.info(" %d", num)
paddings.append(num)
num *= 2
else:
logger.info("Using incremental paddings:")
while num <= padding_gap: while num <= padding_gap:
logger.info(" %d", num)
paddings.append(num) paddings.append(num)
num *= 2 num *= 2
num //= 2 num //= 2
while num < max_token_size: while num < max_token_size:
num += padding_gap num += padding_gap
logger.info(" %d", num)
paddings.append(num) paddings.append(num)
return paddings return paddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment