Commit 856589ed (unverified), authored by Wentao Ye, committed by GitHub
Browse files

[Refactor] Remove dead code in kv connector and model runner (#38383)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 517b769b
...@@ -6,6 +6,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # ...@@ -6,6 +6,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( #
) )
from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized, ensure_kv_transfer_initialized,
ensure_kv_transfer_shutdown,
get_kv_transfer_group, get_kv_transfer_group,
) )
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
...@@ -57,4 +58,4 @@ def test_kv_connector_mixin_clears_metadata(): ...@@ -57,4 +58,4 @@ def test_kv_connector_mixin_clears_metadata():
assert connector.call_record.get("clear_connector_metadata", 0) == 1 assert connector.call_record.get("clear_connector_metadata", 0) == 1
finally: finally:
# Ensure we clean up the global connector between tests # Ensure we clean up the global connector between tests
KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown() ensure_kv_transfer_shutdown()
...@@ -102,10 +102,6 @@ class CPUModelRunner(GPUModelRunner): ...@@ -102,10 +102,6 @@ class CPUModelRunner(GPUModelRunner):
# so stale KV cache data never affects computation. # so stale KV cache data never affects computation.
pass pass
def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]:
# Note: For CPU backend, dp padding is not required for now.
return 0, None
@contextmanager @contextmanager
def _torch_cuda_wrapper(): def _torch_cuda_wrapper():
......
...@@ -34,14 +34,6 @@ class ECConnectorModelRunnerMixin: ...@@ -34,14 +34,6 @@ class ECConnectorModelRunnerMixin:
connector = get_ec_transfer() connector = get_ec_transfer()
connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash) connector.save_caches(encoder_cache=encoder_cache, mm_hash=mm_hash)
@staticmethod
def get_finished_ec_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[set[str] | None, set[str] | None]:
if has_ec_transfer():
return get_ec_transfer().get_finished(scheduler_output.finished_req_ids)
return None, None
@staticmethod @staticmethod
def maybe_get_ec_connector_output( def maybe_get_ec_connector_output(
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
......
...@@ -97,11 +97,6 @@ class ActiveKVConnector(KVConnector): ...@@ -97,11 +97,6 @@ class ActiveKVConnector(KVConnector):
self.kv_connector.clear_connector_metadata() self.kv_connector.clear_connector_metadata()
return output return output
def clear_metadata(self) -> None:
"""Clear the connector metadata. Call this after draft model runs."""
if not self._disabled:
self.kv_connector.clear_connector_metadata()
def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
if self._disabled: if self._disabled:
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
......
...@@ -451,7 +451,6 @@ class GPUModelRunner( ...@@ -451,7 +451,6 @@ class GPUModelRunner(
# Model-related. # Model-related.
self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.inputs_embeds_size = model_config.get_inputs_embeds_size() self.inputs_embeds_size = model_config.get_inputs_embeds_size()
self.attention_chunk_size = model_config.attention_chunk_size
# Only relevant for models using ALiBi (e.g, MPT) # Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = model_config.uses_alibi self.use_alibi = model_config.uses_alibi
...@@ -594,7 +593,6 @@ class GPUModelRunner( ...@@ -594,7 +593,6 @@ class GPUModelRunner(
# NOTE(rob): num_prompt_logprobs only includes reqs # NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase. # that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {} self.num_prompt_logprobs: dict[str, int] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch # Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside # NOTE(Chen): Ideally, we should initialize the input batch inside
......
...@@ -13,11 +13,7 @@ import torch ...@@ -13,11 +13,7 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import ( from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -38,12 +34,6 @@ logger = init_logger(__name__) ...@@ -38,12 +34,6 @@ logger = init_logger(__name__)
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) # Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin: class KVConnectorModelRunnerMixin:
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
ensure_kv_transfer_shutdown()
@staticmethod @staticmethod
def kv_connector_no_forward( def kv_connector_no_forward(
scheduler_output: "SchedulerOutput", vllm_config: VllmConfig scheduler_output: "SchedulerOutput", vllm_config: VllmConfig
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import supports_xpu_graph from vllm.utils.torch_utils import supports_xpu_graph
from vllm.v1.worker.gpu.model_runner import ( from vllm.v1.worker.gpu.model_runner import (
GPUModelRunner as GPUModelRunnerV2, GPUModelRunner as GPUModelRunnerV2,
) )
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
pass
logger = init_logger(__name__)
class XPUModelRunner(GPUModelRunner): class XPUModelRunner(GPUModelRunner):
"""A model runner for XPU devices.""" """A model runner for XPU devices."""
...@@ -47,7 +40,6 @@ class XPUModelRunnerV2(GPUModelRunnerV2): ...@@ -47,7 +40,6 @@ class XPUModelRunnerV2(GPUModelRunnerV2):
@contextmanager @contextmanager
def _torch_cuda_wrapper(): def _torch_cuda_wrapper():
try:
# replace cuda APIs with xpu APIs, this should work by default # replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Stream = torch.xpu.Stream torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream torch.cuda.default_stream = torch.xpu.current_stream
...@@ -61,5 +53,3 @@ def _torch_cuda_wrapper(): ...@@ -61,5 +53,3 @@ def _torch_cuda_wrapper():
torch.cuda.CUDAGraph = torch.xpu.XPUGraph torch.cuda.CUDAGraph = torch.xpu.XPUGraph
torch.cuda.graph_pool_handle = torch.xpu.graph_pool_handle torch.cuda.graph_pool_handle = torch.xpu.graph_pool_handle
yield yield
finally:
pass
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment