Unverified Commit 19a53b27 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[V1] Decouple GPU and TPU `InputBatch` (#19778)


Signed-off-by: default avatarAndrew Feldman <afeldman@redhat.com>
parent eccdc831
...@@ -5,7 +5,7 @@ from typing import Optional ...@@ -5,7 +5,7 @@ from typing import Optional
import torch import torch
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.tpu_input_batch import InputBatch
DEFAULT_SAMPLING_PARAMS = dict( DEFAULT_SAMPLING_PARAMS = dict(
temperature=-1.0, temperature=-1.0,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining an input batch # Datastructures defining a GPU input batch
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, cast from typing import Optional, cast
...@@ -453,6 +453,11 @@ class InputBatch: ...@@ -453,6 +453,11 @@ class InputBatch:
self.block_table.swap_row(i1, i2) self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: list[int]) -> None: def condense(self, empty_req_indices: list[int]) -> None:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs = self.num_reqs num_reqs = self.num_reqs
if num_reqs == 0: if num_reqs == 0:
# The batched states are empty. # The batched states are empty.
......
...@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners. ...@@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
""" """
from contextlib import contextmanager from contextlib import contextmanager
from typing import Union
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
...@@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping ...@@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch = Union[TPUInputBatch, GPUInputBatch]
logger = init_logger(__name__) logger = init_logger(__name__)
......
This diff is collapsed.
...@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ...@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (initialize_kv_cache_for_kv_sharing, from .utils import (initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs) sanity_check_mm_encoder_outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment