Unverified Commit 3408e471 authored by Yihua Cheng's avatar Yihua Cheng Committed by GitHub
Browse files
parent 0377b831
...@@ -13,6 +13,8 @@ import torch.nn as nn ...@@ -13,6 +13,8 @@ import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -987,6 +989,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -987,6 +989,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]: ) -> Union[ModelRunnerOutput, torch.Tensor]:
# Bind the KVConnector metadata from the scheduler output to the KVConnector before forward().
if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata(
scheduler_output.kv_connector_metadata)
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens: if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do. # Return empty ModelRunnerOutput if there's no work to do.
...@@ -1228,6 +1235,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1228,6 +1235,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# in the next step. # in the next step.
del draft_probs del draft_probs
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
return ModelRunnerOutput( return ModelRunnerOutput(
req_ids=self.input_batch.req_ids, req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index, req_id_to_index=self.input_batch.req_id_to_index,
......
...@@ -9,11 +9,12 @@ import torch.distributed ...@@ -9,11 +9,12 @@ import torch.distributed
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -110,7 +111,7 @@ class Worker(WorkerBase): ...@@ -110,7 +111,7 @@ class Worker(WorkerBase):
raise RuntimeError( raise RuntimeError(
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method, self.distributed_init_method,
self.local_rank) self.local_rank)
# Set random seed. # Set random seed.
...@@ -285,12 +286,13 @@ class Worker(WorkerBase): ...@@ -285,12 +286,13 @@ class Worker(WorkerBase):
def init_worker_distributed_environment( def init_worker_distributed_environment(
parallel_config: ParallelConfig, vllm_config: VllmConfig,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
...@@ -299,6 +301,8 @@ def init_worker_distributed_environment( ...@@ -299,6 +301,8 @@ def init_worker_distributed_environment(
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
...@@ -23,7 +23,8 @@ from vllm.attention.backends.abstract import AttentionState ...@@ -23,7 +23,8 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture) graph_capture)
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
......
...@@ -10,10 +10,10 @@ import torch.distributed ...@@ -10,10 +10,10 @@ import torch.distributed
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_kv_transfer_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment