Unverified Commit 471fe656 authored by Chengji Yao's avatar Chengji Yao Committed by GitHub
Browse files

[TPU][V1] Implicitly adjust page size when there's SMEM OOM (#16871)


Signed-off-by: default avatarChengji Yao <chengjiyao@google.com>
parent 3a0fba5c
...@@ -22,6 +22,7 @@ MODELS = [ ...@@ -22,6 +22,7 @@ MODELS = [
] ]
TENSOR_PARALLEL_SIZES = [1] TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]
# TODO: Enable when CI/CD will have a multi-tpu instance # TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4] # TENSOR_PARALLEL_SIZES = [1, 4]
...@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1] ...@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic( def test_basic(
vllm_runner: type[VllmRunner], vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
model: str, model: str,
max_tokens: int, max_tokens: int,
tensor_parallel_size: int, tensor_parallel_size: int,
max_num_seqs: int,
) -> None: ) -> None:
prompt = "The next numbers of the sequence " + ", ".join( prompt = "The next numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:" str(i) for i in range(1024)) + " are:"
...@@ -51,9 +54,9 @@ def test_basic( ...@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to # Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt # actually test chunked prompt
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
max_model_len=8196, max_model_len=8192,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
max_num_seqs=16, max_num_seqs=max_num_seqs,
tensor_parallel_size=tensor_parallel_size) as vllm_model: tensor_parallel_size=tensor_parallel_size) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens) max_tokens)
......
...@@ -97,6 +97,20 @@ class TpuPlatform(Platform): ...@@ -97,6 +97,20 @@ class TpuPlatform(Platform):
"Using bfloat16 instead.", vllm_config.model_config.dtype) "Using bfloat16 instead.", vllm_config.model_config.dtype)
vllm_config.model_config.dtype = torch.bfloat16 vllm_config.model_config.dtype = torch.bfloat16
if envs.VLLM_USE_V1:
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > vllm_config.cache_config.block_size:
logger.warning(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM",
vllm_config.cache_config.block_size,
min_page_size,
)
vllm_config.cache_config.block_size = min_page_size
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
......
...@@ -10,7 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401 ...@@ -10,7 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType) AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,6 +52,19 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -50,6 +52,19 @@ class PallasAttentionBackend(AttentionBackend):
) -> None: ) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.") raise RuntimeError("swap_blocks is not used for the TPU backend.")
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
# block_tables within the PallasMetadata constitute almost the entire SMEM
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
# we simply make sure that the size is smaller than half of SMEM capacity.
@staticmethod
def get_min_page_size(vllm_config: VllmConfig) -> int:
max_num_page_per_req = (1024 * 1024 // 2 //
vllm_config.scheduler_config.max_num_seqs // 4)
min_page_size = cdiv(vllm_config.model_config.max_model_len,
max_num_page_per_req)
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
@dataclass @dataclass
class PallasMetadata: class PallasMetadata:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment