# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. diff -Naur v0.6.3.post1_vllm/_version.py patched_vllm/_version.py --- v0.6.3.post1_vllm/_version.py 2025-01-09 03:03:32.439278263 -0800 +++ patched_vllm/_version.py 2025-01-09 01:49:43.785300620 -0800 @@ -12,5 +12,5 @@ __version_tuple__: VERSION_TUPLE version_tuple: VERSION_TUPLE -__version__ = version = '0.6.3.post1' -__version_tuple__ = version_tuple = (0, 6, 3) +__version__ = version = '0.6.3.post2.dev16+gf61960ce' +__version_tuple__ = version_tuple = (0, 6, 3, 'dev16', 'gf61960ce') diff -Naur v0.6.3.post1_vllm/config.py patched_vllm/config.py --- v0.6.3.post1_vllm/config.py 2025-01-09 03:03:32.439278263 -0800 +++ patched_vllm/config.py 2025-01-09 01:49:43.785300620 -0800 @@ -1046,7 +1046,7 @@ "sequences. Please increase max_num_batched_tokens or " "decrease max_model_len.") - if self.max_num_batched_tokens < self.max_num_seqs: + if self.max_num_seqs is not None and self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " diff -Naur v0.6.3.post1_vllm/core/scheduler.py patched_vllm/core/scheduler.py --- v0.6.3.post1_vllm/core/scheduler.py 2025-01-09 03:03:32.291290245 -0800 +++ patched_vllm/core/scheduler.py 2025-01-09 01:49:43.785300620 -0800 @@ -17,6 +17,7 @@ SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) from vllm.utils import Device, PyObjectCache +import vllm.envs as envs logger = init_logger(__name__) @@ -883,12 +884,17 @@ assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) - if not enable_chunking: - num_prompt_tokens = waiting_seqs[0].get_len() - assert num_new_tokens == num_prompt_tokens + real_num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + enable_chunking, budget) + if envs.VLLM_DISAGG_STAGE == "GENERATE": + num_new_tokens = 1 + assert not enable_chunking + else: + num_new_tokens = real_num_new_tokens + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens prompt_limit = self._get_prompt_limit(seq_group) if num_new_tokens > prompt_limit: @@ -967,7 +973,7 @@ seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=num_new_tokens)) + token_chunk_size=real_num_new_tokens)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) diff -Naur v0.6.3.post1_vllm/distributed/parallel_state.py patched_vllm/distributed/parallel_state.py --- v0.6.3.post1_vllm/distributed/parallel_state.py 2025-01-09 03:03:32.291290245 -0800 +++ patched_vllm/distributed/parallel_state.py 2025-01-09 01:49:43.785300620 -0800 @@ -889,6 +889,14 @@ 
get_pipeline_model_parallel_group = get_pp_group +_STORE: Optional[Any] = None + +def get_store() -> Any: + assert _STORE is not None, ("store is not initialized") + return _STORE + + + @contextmanager def graph_capture(): """ @@ -926,20 +934,51 @@ local_rank: int = -1, backend: str = "nccl", ): + + # TODO ptarasiewicz this is a hack to get the stage from the environment + logger.info("="*50) + logger.info("Patching init_distributed_environment") + stage = envs.VLLM_DISAGG_STAGE + logger.info(f"Stage: {stage}") + store = None + if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl": + context_workers = envs.VLLM_CONTEXT_WORKERS + context_tp_size = envs.VLLM_CONTEXT_TP_SIZE + generate_workers = envs.VLLM_GENERATE_WORKERS + generate_tp_size = envs.VLLM_GENERATE_TP_SIZE + world_size = context_workers * context_tp_size + generate_workers * generate_tp_size + if stage == "PREFILL": + worker_id = envs.VLLM_WORKER_ID + rank += worker_id * context_tp_size + if stage == "GENERATE": + rank += context_workers * context_tp_size # TODO ptarasiewicz this only works for 1 generate worker + # distributed_init_method = f"tcp://{envs.VLLM_TORCH_HOST}:{envs.VLLM_TORCH_PORT}" + distributed_init_method = None + store = torch.distributed.TCPStore(envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, world_size=world_size, is_master = rank == 0) + logger.info(f"world_size: {world_size}, rank: {rank}, distributed_init_method: {distributed_init_method}, local_rank: {local_rank}, backend: {backend}") + logger.debug( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) if not torch.distributed.is_initialized(): - assert distributed_init_method is not None, ( - "distributed_init_method must be provided when initializing " + assert distributed_init_method is not None or store is not None, ( + "distributed_init_method or store must be provided when initializing " "distributed environment") # this backend is used for WORLD - torch.distributed.init_process_group( - backend=backend, - init_method=distributed_init_method, - world_size=world_size, - rank=rank) + if store is None: + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank) + else: + torch.distributed.init_process_group( + backend=backend, + # init_method=distributed_init_method, + world_size=world_size, + rank=rank, + store=store) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -958,6 +997,10 @@ assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") + global _STORE + if store is not None and _STORE is None: + _STORE = store + def initialize_model_parallel( tensor_model_parallel_size: int = 1, @@ -986,32 +1029,60 @@ with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. """ + logger.debug("="*50) + logger.debug("Patching initialize_model_parallel") # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): - raise RuntimeError( - f"world_size ({world_size}) is not equal to " - f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + # TODO ptarasiewicz this assertion does not work with disagg + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") # Build the tensor model-parallel groups. num_tensor_model_parallel_groups: int = (world_size // - tensor_model_parallel_size) + tensor_model_parallel_size) + + stage = envs.VLLM_DISAGG_STAGE + global _TP assert _TP is None, ("tensor model parallel group is already initialized") group_ranks = [] for i in range(num_tensor_model_parallel_groups): - ranks = list( - range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) + + # TODO ptarasiewicz this is a hack to adjust the ranks for the stage + if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl": + ranks = [] + context_workers = envs.VLLM_CONTEXT_WORKERS + context_tp_size = envs.VLLM_CONTEXT_TP_SIZE + generate_workers = envs.VLLM_GENERATE_WORKERS + generate_tp_size = envs.VLLM_GENERATE_TP_SIZE + for context_id in range(context_workers): + ranks.append( + [context_id * context_tp_size + i for i in range(context_tp_size)] + ) + for generate_id in range(generate_workers): + ranks.append( + [context_workers * context_tp_size + generate_id * generate_tp_size + i for i in range(generate_tp_size)] + ) + group_ranks.extend(ranks) + break + else: + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) # message queue broadcaster is only used in tensor model parallel group + logger.debug("initializing tensor model parallel group") + logger.debug(f"group_ranks {group_ranks}") _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, @@ -1020,15 +1091,32 @@ # Build the pipeline model-parallel groups. 
num_pipeline_model_parallel_groups: int = (world_size // - pipeline_model_parallel_size) + pipeline_model_parallel_size) + global _PP assert _PP is None, ( "pipeline model parallel group is already initialized") group_ranks = [] for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) + + + + # TODO ptarasiewicz this is a hack to adjust the ranks for the stage + if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl": + context_workers = envs.VLLM_CONTEXT_WORKERS + context_tp_size = envs.VLLM_CONTEXT_TP_SIZE + generate_workers = envs.VLLM_GENERATE_WORKERS + generate_tp_size = envs.VLLM_GENERATE_TP_SIZE + world_size = context_workers * context_tp_size + generate_workers * generate_tp_size + ranks = [[i] for i in range(world_size)] + group_ranks.extend(ranks) + break + else: + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) # pipeline parallel does not need custom allreduce + logger.debug("initializing pipeline model parallel group") + logger.debug(f"group_ranks {group_ranks}") _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, diff -Naur v0.6.3.post1_vllm/engine/async_llm_engine.py patched_vllm/engine/async_llm_engine.py --- v0.6.3.post1_vllm/engine/async_llm_engine.py 2025-01-09 03:03:32.443277939 -0800 +++ patched_vllm/engine/async_llm_engine.py 2025-01-09 01:49:43.785300620 -0800 @@ -371,7 +371,7 @@ # multi_step_model_runner does the first-step output append. is_first_step_output: bool = False if not seq_group_metadata_list \ else seq_group_metadata_list[0].state.num_steps == 1 - + ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, diff -Naur v0.6.3.post1_vllm/engine/llm_engine.py patched_vllm/engine/llm_engine.py --- v0.6.3.post1_vllm/engine/llm_engine.py 2025-01-09 03:03:32.443277939 -0800 +++ patched_vllm/engine/llm_engine.py 2025-01-09 01:49:43.785300620 -0800 @@ -479,8 +479,28 @@ The workers will determine the number of blocks in both the GPU cache and the swap CPU cache. 
""" - num_gpu_blocks, num_cpu_blocks = ( - self.model_executor.determine_num_available_blocks()) + + + max_num_seqs = None + if self.scheduler_config.max_num_seqs is not None: + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + else: + max_num_seqs = 1 + max_concurrency = None + max_iter_count_left = 5 + while True: + logger.info("Profiling with %d sequences", max_num_seqs) + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks(max_num_seqs)) + max_concurrency = (num_gpu_blocks * self.cache_config.block_size / self.model_config.max_model_len) + logger.info("Maximum concurrency for %d sequences and %s tokens per request: %.2fx", + max_num_seqs, self.model_config.max_model_len, max_concurrency) + max_iter_count_left -= 1 + if max_iter_count_left < 1 or (max_concurrency - max_num_seqs) ** 2 < 2: + break + max_num_seqs = int(max_concurrency) + if self.cache_config.num_gpu_blocks_override is not None: num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override @@ -494,6 +514,10 @@ self.cache_config.num_cpu_blocks = num_cpu_blocks self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + if max_num_seqs is not None and self.scheduler_config.max_num_seqs is None: + self.scheduler_config.max_num_seqs = max_num_seqs + logger.info("Setting max_num_seqs to %d", max_num_seqs) @classmethod def _get_executor_cls(cls, diff -Naur v0.6.3.post1_vllm/envs.py patched_vllm/envs.py --- v0.6.3.post1_vllm/envs.py 2025-01-09 03:03:32.439278263 -0800 +++ patched_vllm/envs.py 2025-01-09 01:49:43.789300297 -0800 @@ -66,6 +66,15 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_DISABLED_KERNELS: List[str] = [] + VLLM_DISAGG_STAGE: Optional[str] = None + VLLM_CONTEXT_TP_SIZE: int = 0 + VLLM_CONTEXT_WORKERS: int = 0 + VLLM_GENERATE_TP_SIZE: int = 0 + VLLM_GENERATE_WORKERS: int = 0 + VLLM_TORCH_HOST: str = "localhost" + VLLM_TORCH_PORT: int = 36183 + VLLM_WORKER_ID: int = 0 + VLLM_DATA_PLANE_BACKEND: str = "nccl" def get_default_cache_root(): @@ -433,6 +442,35 @@ "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ "VLLM_DISABLED_KERNELS"].split(","), + + # Stage of the disaggregated model + "VLLM_DISAGG_STAGE": + lambda: os.getenv("VLLM_DISAGG_STAGE", None), + + # Disaggregation global config + "VLLM_CONTEXT_TP_SIZE": + lambda: int(os.getenv("VLLM_CONTEXT_TP_SIZE", "0")), + + "VLLM_CONTEXT_WORKERS": + lambda: int(os.getenv("VLLM_CONTEXT_WORKERS", "0")), + + "VLLM_GENERATE_TP_SIZE": + lambda: int(os.getenv("VLLM_GENERATE_TP_SIZE", "0")), + + "VLLM_GENERATE_WORKERS": + lambda: int(os.getenv("VLLM_GENERATE_WORKERS", "0")), + + "VLLM_TORCH_HOST": + lambda: os.getenv("VLLM_TORCH_HOST", "localhost"), + + "VLLM_TORCH_PORT": + lambda: int(os.getenv("VLLM_TORCH_PORT", "36183")), + + "VLLM_WORKER_ID": + lambda: int(os.getenv("VLLM_WORKER_ID", "0")), + + "VLLM_DATA_PLANE_BACKEND": + lambda: os.getenv("VLLM_DATA_PLANE_BACKEND", "nccl"), } # end-env-vars-definition diff -Naur v0.6.3.post1_vllm/executor/distributed_gpu_executor.py patched_vllm/executor/distributed_gpu_executor.py --- v0.6.3.post1_vllm/executor/distributed_gpu_executor.py 2025-01-09 03:03:32.443277939 -0800 +++ patched_vllm/executor/distributed_gpu_executor.py 2025-01-09 01:49:43.789300297 -0800 @@ -25,7 +25,7 @@ super().__init__(*args, **kwargs) - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, 
int]: """Determine the number of available KV blocks. This invokes `determine_num_available_blocks` on each worker and takes @@ -36,7 +36,7 @@ - tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) + num_blocks = self._run_workers("determine_num_available_blocks", max_num_seqs=max_num_seqs) # Since we use a shared centralized controller, we take the minimum # number of blocks across all workers to make sure all the memory diff -Naur v0.6.3.post1_vllm/executor/gpu_executor.py patched_vllm/executor/gpu_executor.py --- v0.6.3.post1_vllm/executor/gpu_executor.py 2025-01-09 03:03:32.443277939 -0800 +++ patched_vllm/executor/gpu_executor.py 2025-01-09 01:49:43.789300297 -0800 @@ -107,11 +107,11 @@ rank=rank, distributed_init_method=distributed_init_method)) - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ - return self.driver_worker.determine_num_available_blocks() + return self.driver_worker.determine_num_available_blocks(max_num_seqs=max_num_seqs) def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: """Initialize the KV cache by invoking the underlying worker. diff -Naur v0.6.3.post1_vllm/worker/model_runner.py patched_vllm/worker/model_runner.py --- v0.6.3.post1_vllm/worker/model_runner.py 2025-01-09 03:03:32.559268548 -0800 +++ patched_vllm/worker/model_runner.py 2025-01-09 01:49:43.993283816 -0800 @@ -1,7 +1,9 @@ +import time import dataclasses import gc import inspect import itertools +import os import time import warnings import weakref @@ -25,6 +27,7 @@ PromptAdapterConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group +from vllm.distributed.kv_cache import get_kv_cache_handler from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -46,7 +49,7 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput from vllm.transformers_utils.config import uses_mrope from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, @@ -58,6 +61,7 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -120,6 +124,7 @@ "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, + "seq_lens": self.seq_lens, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -158,6 +163,7 @@ "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, + "seq_lens": self.seq_lens, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -959,6 +965,7 @@ 
input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): + self._kv_cache_handler = None self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -978,8 +985,7 @@ self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = _get_max_graph_batch_size( - self.scheduler_config.max_num_seqs) + self.max_batchsize_to_capture = None self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) @@ -995,9 +1001,7 @@ # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables = np.zeros( - (self.max_batchsize_to_capture, self.get_max_block_per_batch()), - dtype=np.int32) + self.graph_block_tables = None # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. @@ -1196,11 +1200,18 @@ return builder.build() # type: ignore @torch.inference_mode() - def profile_run(self) -> None: + def profile_run(self, max_num_seqs: Optional[int] = None) -> None: # Enable top-k sampling to reflect the accurate memory usage. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs + if max_num_seqs is None: + max_num_seqs = self.scheduler_config.max_num_seqs + + self.max_batchsize_to_capture = _get_max_graph_batch_size( + max_num_seqs) + self.graph_block_tables = np.zeros( + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), + dtype=np.int32) # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -1522,6 +1533,11 @@ # This usually takes < 10 seconds. 
logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + def init_kv_cache_handler(self) -> None: + if envs.VLLM_DISAGG_STAGE is not None: + self._kv_cache_handler = get_kv_cache_handler() + torch.distributed.barrier() + def _update_inputs_to_capture_for_enc_dec_model(self, capture_inputs: Dict[str, Any]): @@ -1552,6 +1568,13 @@ _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( ModelInputForGPUWithSamplingMetadata) _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder + _first_tokens = {} + + def set_first_token(self, request_id: str, first_token: int) -> None: + self._first_tokens[request_id] = first_token + + def pop_first_token(self, request_id: str) -> int: + return self._first_tokens.pop(request_id) def make_model_input_from_broadcasted_tensor_dict( self, @@ -1610,6 +1633,8 @@ intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + logger.debug(f"execute model input_ids: {model_input.input_tokens.shape}") + # logger.info(f"model input {model_input}") if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1654,21 +1679,98 @@ model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + # Check if we need to run inference or not + # We do not run inference if we are in generating stage of disaggregated serving + disagg_stage = envs.VLLM_DISAGG_STAGE + assert disagg_stage in ["PREFILL", "GENERATE", None], f"Invalid disagg_stage: {disagg_stage}, must be one of ['PREFILL', 'GENERATE', None]" + is_profile_run = ( + (kv_caches is None) or (kv_caches[0] is None) or (kv_caches[0].numel() == 0) + ) + is_generate = disagg_stage == "GENERATE" and prefill_meta is None + run_inference = any( + [ + disagg_stage == "PREFILL", + is_profile_run, + is_generate, + disagg_stage is None, + ] + ) + + logger.debug( + f"Running inference: {run_inference}, disagg_stage: {disagg_stage}, " + f"is_profile_run: {is_profile_run}, is_generate: {is_generate}" + ) + logger.debug(f"Batch size: {model_input.input_tokens.shape} Prefill: {prefill_meta is not None}, Generate: {prefill_meta is None}") + + if not run_inference: + if self.is_driver_worker: + logger.debug(f"Pulling KV cache for seq lens: {model_input.seq_lens}") + # We are in the GENERATE stage of disaggregated serving + # instead of running inference, we just need to pull the KV cache + # and hidden states from the context stage + # TODO ptarasiewicz check why without torch.cuda.synchronize() thorughput is lower + logger.debug("PULLING KV CACHE") + # torch.cuda.synchronize() + # start_time = time.perf_counter_ns() + self._kv_cache_handler.recv( + self.model.model, + model_input, + kv_caches, + ) + # torch.cuda.synchronize() + # end_time = time.perf_counter_ns() + # logger.info(f"KV CACHE PULL TIME: {(end_time - start_time) / 1e6} ms") + + if not self.is_driver_worker: + return [] + + fist_token = self.pop_first_token(list(model_input.request_ids_to_seq_ids.keys())[0]) + mocked_output = SamplerOutput( + outputs=[ + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq_group.seq_ids[0], + output_token=fist_token, + logprobs={fist_token: Logprob(float('inf'))}, + ) + ], + prompt_logprobs=None, + ) + for seq_group in model_input.sampling_metadata.seq_groups + ], + ) + logger.debug(f"MOCKED OUTPUT {mocked_output}") + return [mocked_output] + + logger.debug("RUNNING INFERENCE") with set_forward_context(model_input.attn_metadata): - 
hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with torch.cuda.nvtx.range("model_executable"): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() + if disagg_stage == "PREFILL" and not is_profile_run: + # TODO ptarasiewicz check why without torch.cuda.synchronize() thorughput is lower + logger.debug("Pushing KV cache") + torch.cuda.synchronize() + self._kv_cache_handler.send( + self.model.model, + model_input, + kv_caches, + ) + logger.debug("finished sending") + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker @@ -1687,6 +1789,31 @@ hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states + + if not run_inference: + + if not self.is_driver_worker: + return [] + + fist_token = self.pop_first_token(list(model_input.request_ids_to_seq_ids.keys())[0]) + mocked_output = SamplerOutput( + outputs=[ + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq_group.seq_ids[0], + output_token=fist_token, + logprobs={fist_token: Logprob(float('inf'))}, + ) + ], + prompt_logprobs=None, + ) + for seq_group in model_input.sampling_metadata.seq_groups + ], + ) + # print("OUTPUT", output) + print("MOCKED OUTPUT", mocked_output) + return [mocked_output] logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1702,6 +1829,7 @@ logits=logits, sampling_metadata=model_input.sampling_metadata, ) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time and output is not None): diff -Naur v0.6.3.post1_vllm/worker/worker.py patched_vllm/worker/worker.py --- v0.6.3.post1_vllm/worker/worker.py 2025-01-09 03:03:32.559268548 -0800 +++ patched_vllm/worker/worker.py 2025-01-09 01:49:43.993283816 -0800 @@ -202,7 +202,7 @@ tensorizer_config=tensorizer_config, ) @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. @@ -214,13 +214,14 @@ You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ + # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() + self.model_runner.profile_run(max_num_seqs) # Calculate the number of blocks that can be allocated with the # profiled peak memory. 
@@ -242,16 +243,18 @@ else: num_gpu_blocks = int( (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) + peak_memory) // cache_block_size) num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: @@ -285,6 +288,7 @@ def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) + self.model_runner.init_kv_cache_handler() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed)
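The hunks above size the disaggregated world from the VLLM_CONTEXT_* / VLLM_GENERATE_* environment variables added to envs.py and offset each worker's global rank by stage in init_distributed_environment. A minimal sketch of that layout arithmetic, assuming the same variable names and defaults as the patched envs.py (the concrete worker counts and TP sizes below are illustrative, not taken from the patch):

    import os

    def disagg_rank_layout(stage: str, worker_id: int, local_rank: int) -> tuple:
        # Mirrors the rank/world-size arithmetic in the patched
        # init_distributed_environment (nccl data-plane backend).
        context_workers = int(os.getenv("VLLM_CONTEXT_WORKERS", "0"))
        context_tp_size = int(os.getenv("VLLM_CONTEXT_TP_SIZE", "0"))
        generate_workers = int(os.getenv("VLLM_GENERATE_WORKERS", "0"))
        generate_tp_size = int(os.getenv("VLLM_GENERATE_TP_SIZE", "0"))

        world_size = (context_workers * context_tp_size
                      + generate_workers * generate_tp_size)
        rank = local_rank
        if stage == "PREFILL":
            # Prefill workers occupy the first context_workers * context_tp_size ranks.
            rank += worker_id * context_tp_size
        elif stage == "GENERATE":
            # Generate ranks start after all prefill ranks; per the TODO in the
            # patch, this currently assumes a single generate worker.
            rank += context_workers * context_tp_size
        return rank, world_size

    # Example with illustrative values: 2 prefill workers at TP=2 plus 1 generate
    # worker at TP=2 gives world_size=6; prefill worker 1, local rank 0 maps to
    # global rank 2, and a generate worker's local rank 0 maps to global rank 4.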