Unverified commit 19a53b27, authored by afeldman-nm, committed by GitHub
Browse files

[V1] Decouple GPU and TPU `InputBatch` (#19778)


Signed-off-by: Andrew Feldman <afeldman@redhat.com>
parent eccdc831
......@@ -5,7 +5,7 @@ from typing import Optional
import torch
-from vllm.v1.worker.gpu_input_batch import InputBatch
+from vllm.v1.worker.tpu_input_batch import InputBatch
DEFAULT_SAMPLING_PARAMS = dict(
temperature=-1.0,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Datastructures defining an input batch
+# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, cast
......@@ -453,6 +453,11 @@ class InputBatch:
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: list[int]) -> None:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
......
......@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import Union
import numpy as np
import torch.nn as nn
......@@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
-from vllm.v1.worker.gpu_input_batch import InputBatch
+from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
+from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
+InputBatch = Union[TPUInputBatch, GPUInputBatch]
logger = init_logger(__name__)
......
This diff is collapsed.
......@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment