Unverified commit 10f75527, authored by iefgnoix, committed by GitHub
Browse files

[V1][TPU] Remove unnecessary padding for running on TPU. (#14467)

parent b0d54194
...@@ -12,8 +12,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -12,8 +12,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel. # These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 16 NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 256 NUM_KV_PAGES_PER_BLOCK = 128
class PallasAttentionBackend(AttentionBackend): class PallasAttentionBackend(AttentionBackend):
......
...@@ -23,9 +23,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality ...@@ -23,9 +23,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
NUM_QUERIES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata) PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...@@ -78,10 +76,8 @@ class TPUModelRunner: ...@@ -78,10 +76,8 @@ class TPUModelRunner:
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = _get_padded_number( self.max_num_tokens = scheduler_config.max_num_batched_tokens
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK) self.max_num_reqs = scheduler_config.max_num_seqs
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
NUM_QUERIES_PER_BLOCK)
# Model-related. # Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type( self.num_attn_layers = model_config.get_num_layers_by_block_type(
...@@ -142,16 +138,8 @@ class TPUModelRunner: ...@@ -142,16 +138,8 @@ class TPUModelRunner:
device="cpu") device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.slot_mapping_np = self.slot_mapping_cpu.numpy()
# self.input_batch.block_table has a shape of [max_num_reqs,
# max_num_blocks_per_req]. To reduce the number of recompilation,
# we want the block_table.shape[0] to be num_tokens.
# To make the block_table to be compatible with the paged attention
# kernel, we want the block_table[1] to be multiple of
# NUM_KV_PAGES_PER_BLOCK.
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros( self.block_table_cpu = torch.zeros(
(self.max_num_tokens, padded_max_num_blocks_per_req), (self.max_num_tokens, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype, dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu") device="cpu")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment