Unverified Commit ff7ec82c authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

parent 200a2ffa
......@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
msgspec
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
......
......@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
import pytest
from prometheus_client import REGISTRY
import vllm.envs as envs
from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT)
......@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"tests/basic_correctness/test_preemption.py`")
@pytest.fixture
def worker_use_ray() -> bool:
# When SPMD worker is used, use ray_use_worker=True
# to test delta input optimization works with preemption.
return envs.VLLM_USE_RAY_SPMD_WORKER
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
......@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
"""Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256)
......@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
......@@ -79,6 +89,7 @@ def test_preemption(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""By default, recompute preemption is enabled"""
......@@ -89,6 +100,7 @@ def test_preemption(
model,
dtype=dtype,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
......@@ -132,6 +144,7 @@ def test_swap(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
......@@ -144,6 +157,7 @@ def test_swap(
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
......@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
......@@ -204,6 +219,7 @@ def test_swap_infeasible(
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
......@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:
"""Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16
......@@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
worker_use_ray=worker_use_ray,
) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
......
import msgspec
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest
from ..spec_decode.utils import create_batch
def test_msgspec_serialization():
num_lookahead_slots = 4
seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=4)
encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
req = decoder.decode(encoder.encode(execute_model_req))
expected = execute_model_req.seq_group_metadata_list
actual = req.seq_group_metadata_list
assert (len(expected) == len(actual))
expected = expected[0]
actual = actual[0]
assert expected.block_tables == actual.block_tables
assert expected.is_prompt == actual.is_prompt
assert expected.request_id == actual.request_id
assert (expected.seq_data[0].prompt_token_ids ==
actual.seq_data[0].prompt_token_ids)
assert (expected.seq_data[0].output_token_ids ==
actual.seq_data[0].output_token_ids)
......@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, test_suite", [
"model, distributed_executor_backend, attention_backend, "
"test_suite", [
("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
......
......@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
```
"""
import os
import pytest
from vllm.utils import cuda_device_count_stateless
......@@ -30,6 +32,11 @@ def test_models(
model: str,
distributed_executor_backend: str,
) -> None:
if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa
assert distributed_executor_backend == "ray"
# test ray adag
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
dtype = "half"
max_tokens = 5
......
import itertools
import random
from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
......@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import Counter, is_pin_memory_available
......@@ -56,7 +58,9 @@ def _do_sample(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params,
block_tables={0: [1]},
))
......@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
array(VLLM_TOKEN_ID_ARRAY_TYPE,
random.choices(range(0, VOCAB_SIZE), k=num_input)))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
......@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params,
block_tables={0: [1]},
))
......@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
......@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0:
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
[1, 2, 3]))
},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
......
from array import array
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
......@@ -9,7 +10,8 @@ import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
......@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data={
i:
SequenceData(
prompt_token_ids=prompt_token_ids[:],
output_token_ids=cont_token_ids[:],
array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
_output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
cont_token_ids[:]),
),
},
sampling_params=SamplingParams(temperature=0.0, ),
......
import random
from array import array
from typing import Tuple
from unittest.mock import patch
......@@ -8,7 +9,8 @@ import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available
......@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]),
block_tables={0: [1]},
......
from array import array
import pytest
from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt
......@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4])
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0
# advance by 2
......
from array import array
from typing import List
import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
......@@ -125,10 +127,12 @@ def test_prepare_prompt(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
......@@ -319,10 +323,12 @@ def test_prepare_decode(
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
......
from array import array
from typing import List
import pytest
......@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
......@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
......@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = SequenceData(list(range(context_len)))
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
......@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
......@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest(ABC):
"""
Base class for adapter requests.
......
......@@ -770,8 +770,8 @@ class ParallelConfig:
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
......@@ -867,6 +867,11 @@ class SchedulerConfig:
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead.
send_delta_data: Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1
"""
def __init__(self,
......@@ -879,7 +884,8 @@ class SchedulerConfig:
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1) -> None:
num_scheduler_steps: int = 1,
send_delta_data: bool = False) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
......@@ -909,6 +915,7 @@ class SchedulerConfig:
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.send_delta_data = send_delta_data
self._verify_args()
def _verify_args(self) -> None:
......
......@@ -12,7 +12,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import PyObjectCache
logger = init_logger(__name__)
......@@ -363,8 +364,6 @@ class Scheduler:
self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
seq_group_metadata_builder)
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
......@@ -1048,15 +1047,10 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache.get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers
block_tables: Dict[int,
List[int]] = seq_group_metadata.block_tables
block_tables: Dict[int, List[int]] = {}
if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
......@@ -1081,24 +1075,29 @@ class Scheduler:
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True
if seq_group.is_prefill():
is_prompt = seq_group.is_prefill()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill = False
if is_prompt:
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
num_computed_tokens = seqs[0].data.get_num_computed_tokens()
is_first_prefill = num_computed_tokens == 0
# In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
if (token_chunk_size + num_computed_tokens <
seqs[0].data.get_len()):
do_sample = False
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = seq_group.is_prefill()
seq_group_metadata.__init__(
if is_first_prefill or not self.scheduler_config.send_delta_data:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
seq_data=seq_data,
......@@ -1120,6 +1119,21 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
seq_data_delta = {}
for id, data in seq_data.items():
seq_data_delta[id] = data.get_delta_and_reset()
seq_group_metadata = SequenceGroupMetadataDelta(
seq_data_delta,
seq_group.request_id,
block_tables,
is_prompt,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
computed_block_nums=common_computed_block_nums,
)
seq_group_metadata_list.append(seq_group_metadata)
# Now that the batch has been created, we can assume all blocks in the
......@@ -1130,8 +1144,6 @@ class Scheduler:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
self._seq_group_metadata_cache.reset()
scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant
......
......@@ -5,6 +5,7 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union)
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig,
......@@ -905,6 +906,8 @@ class EngineArgs:
embedding_mode=model_config.embedding_mode,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
......
......@@ -224,7 +224,6 @@ class LLMEngine:
cache_config.enable_prefix_caching,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()
......
from array import array
from typing import Any, Type
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if isinstance(obj, array):
assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if type is array:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
......@@ -4,9 +4,12 @@ from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import msgspec
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
......@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
......@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
......@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req))
return outputs[0]
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers(
self,
......@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req)
serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future
return outputs[0]
return self.output_decoder.decode(outputs[0])
async def _driver_execute_model_async(
self,
......
from typing import List, Optional, Tuple, Union
import msgspec
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
......@@ -24,6 +27,10 @@ try:
# that thread.
self.compiled_dag_cuda_device_set = False
self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
def get_node_ip(self) -> str:
return get_ip()
......@@ -33,16 +40,26 @@ try:
return node_id, gpu_ids
def execute_model_spmd(
self, req_or_tuple: Union[ExecuteModelRequest,
Tuple[ExecuteModelRequest,
IntermediateTensors]]):
self, req_or_tuple: Union[bytes,
Tuple[bytes,
Optional[IntermediateTensors]]]
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
req_or_tuple: The request to execute the model, or a tuple
containing the request and intermediate tensors.
req_or_tuple: A request or a tuple containing the
request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
"""
if isinstance(req_or_tuple, bytes):
serialized_req, intermediate_tensors = req_or_tuple, None
else:
serialized_req, intermediate_tensors = req_or_tuple
execute_model_req = self.input_decoder.decode(serialized_req)
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
......@@ -51,16 +68,14 @@ try:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
if isinstance(req_or_tuple, tuple):
execute_model_req, intermediate_tensors = req_or_tuple
else:
execute_model_req = req_or_tuple
intermediate_tensors = None
output = self.worker._execute_model_spmd(execute_model_req,
intermediate_tensors)
# Pipeline model request and output to the next pipeline stage.
if isinstance(output, IntermediateTensors):
return execute_model_req, output
output = serialized_req, output
else:
output = self.output_encoder.encode(output)
return output
ray_import_err = None
......
import functools
from array import array
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
......@@ -21,6 +22,10 @@ logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@dataclass(frozen=True)
class InputContext:
......@@ -118,7 +123,8 @@ class InputRegistry:
# Avoid circular import
from vllm.sequence import SequenceData
dummy_seq_data = SequenceData([0] * seq_len)
dummy_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment