Unverified Commit ff7ec82c authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

parent 200a2ffa
...@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 ...@@ -21,6 +21,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10 typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq pyzmq
msgspec
librosa # Required for audio processing librosa # Required for audio processing
soundfile # Required for audio processing soundfile # Required for audio processing
gguf == 0.9.1 gguf == 0.9.1
......
...@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`. ...@@ -8,6 +8,7 @@ pytest tests/basic_correctness/test_preemption.py`.
import pytest import pytest
from prometheus_client import REGISTRY from prometheus_client import REGISTRY
import vllm.envs as envs
from vllm import SamplingParams from vllm import SamplingParams
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
ENABLE_ARTIFICIAL_PREEMPT) ENABLE_ARTIFICIAL_PREEMPT)
...@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, ( ...@@ -24,6 +25,13 @@ assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"tests/basic_correctness/test_preemption.py`") "tests/basic_correctness/test_preemption.py`")
@pytest.fixture
def worker_use_ray() -> bool:
# When SPMD worker is used, use ray_use_worker=True
# to test delta input optimization works with preemption.
return envs.VLLM_USE_RAY_SPMD_WORKER
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
...@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute( ...@@ -36,6 +44,7 @@ def test_chunked_prefill_recompute(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
chunked_prefill_token_size: int, chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None: ) -> None:
"""Ensure that chunked prefill works with preemption.""" """Ensure that chunked prefill works with preemption."""
max_num_seqs = min(chunked_prefill_token_size, 256) max_num_seqs = min(chunked_prefill_token_size, 256)
...@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute( ...@@ -54,6 +63,7 @@ def test_chunked_prefill_recompute(
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
worker_use_ray=worker_use_ray,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
...@@ -79,6 +89,7 @@ def test_preemption( ...@@ -79,6 +89,7 @@ def test_preemption(
model: str, model: str,
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
worker_use_ray: bool,
) -> None: ) -> None:
"""By default, recompute preemption is enabled""" """By default, recompute preemption is enabled"""
...@@ -89,6 +100,7 @@ def test_preemption( ...@@ -89,6 +100,7 @@ def test_preemption(
model, model,
dtype=dtype, dtype=dtype,
disable_log_stats=False, disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
...@@ -132,6 +144,7 @@ def test_swap( ...@@ -132,6 +144,7 @@ def test_swap(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
beam_width: int, beam_width: int,
worker_use_ray: bool,
) -> None: ) -> None:
"""Use beam search enables swapping.""" """Use beam search enables swapping."""
example_prompts = example_prompts[:1] example_prompts = example_prompts[:1]
...@@ -144,6 +157,7 @@ def test_swap( ...@@ -144,6 +157,7 @@ def test_swap(
dtype=dtype, dtype=dtype,
swap_space=10, swap_space=10,
disable_log_stats=False, disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts, vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens) beam_width, max_tokens)
...@@ -188,6 +202,7 @@ def test_swap_infeasible( ...@@ -188,6 +202,7 @@ def test_swap_infeasible(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
beam_width: int, beam_width: int,
worker_use_ray: bool,
) -> None: ) -> None:
"""Verify infeasible swap request will be ignored.""" """Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16 BLOCK_SIZE = 16
...@@ -204,6 +219,7 @@ def test_swap_infeasible( ...@@ -204,6 +219,7 @@ def test_swap_infeasible(
# decode blocks are not enough to finish. # decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks, num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
) as vllm_model: ) as vllm_model:
sampling_params = SamplingParams(n=beam_width, sampling_params = SamplingParams(n=beam_width,
use_beam_search=True, use_beam_search=True,
...@@ -230,6 +246,7 @@ def test_preemption_infeasible( ...@@ -230,6 +246,7 @@ def test_preemption_infeasible(
model: str, model: str,
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
worker_use_ray: bool,
) -> None: ) -> None:
"""Verify infeasible preemption request will be ignored.""" """Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16 BLOCK_SIZE = 16
...@@ -244,6 +261,7 @@ def test_preemption_infeasible( ...@@ -244,6 +261,7 @@ def test_preemption_infeasible(
# ignored instead of hanging forever. # ignored instead of hanging forever.
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
worker_use_ray=worker_use_ray,
) as vllm_model: ) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens, sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True) ignore_eos=True)
......
import msgspec
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.sequence import ExecuteModelRequest
from ..spec_decode.utils import create_batch
def test_msgspec_serialization():
num_lookahead_slots = 4
seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=4)
encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
req = decoder.decode(encoder.encode(execute_model_req))
expected = execute_model_req.seq_group_metadata_list
actual = req.seq_group_metadata_list
assert (len(expected) == len(actual))
expected = expected[0]
actual = actual[0]
assert expected.block_tables == actual.block_tables
assert expected.is_prompt == actual.is_prompt
assert expected.request_id == actual.request_id
assert (expected.seq_data[0].prompt_token_ids ==
actual.seq_data[0].prompt_token_ids)
assert (expected.seq_data[0].output_token_ids ==
actual.seq_data[0].output_token_ids)
...@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") ...@@ -22,7 +22,8 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@pytest.mark.skipif(cuda_device_count_stateless() < 2, @pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.") reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, test_suite", [ "model, distributed_executor_backend, attention_backend, "
"test_suite", [
("facebook/opt-125m", "ray", "", "L4"), ("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"), ("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
......
...@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py ...@@ -6,6 +6,8 @@ pytest test_chunked_prefill_distributed.py
``` ```
""" """
import os
import pytest import pytest
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
...@@ -30,6 +32,11 @@ def test_models( ...@@ -30,6 +32,11 @@ def test_models(
model: str, model: str,
distributed_executor_backend: str, distributed_executor_backend: str,
) -> None: ) -> None:
if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa
assert distributed_executor_backend == "ray"
# test ray adag
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
dtype = "half" dtype = "half"
max_tokens = 5 max_tokens = 5
......
import itertools import itertools
import random import random
from array import array
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
...@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin ...@@ -10,7 +11,8 @@ from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import Counter, is_pin_memory_available from vllm.utils import Counter, is_pin_memory_available
...@@ -56,7 +58,9 @@ def _do_sample( ...@@ -56,7 +58,9 @@ def _do_sample(
SequenceGroupMetadata( SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])}, seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
...@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): ...@@ -201,7 +205,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sequence_data(num_input=3, num_generated=0): def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData( seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input)) array(VLLM_TOKEN_ID_ARRAY_TYPE,
random.choices(range(0, VOCAB_SIZE), k=num_input)))
if num_generated > 0: if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated) k=num_generated)
...@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str): ...@@ -504,7 +509,9 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata( SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])}, seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
...@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -600,7 +607,9 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata( SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])}, seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams( sampling_params=SamplingParams(
temperature=1, temperature=1,
top_k=top_k, top_k=top_k,
...@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str): ...@@ -650,7 +659,11 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata( SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])}, seq_data={
0:
SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
[1, 2, 3]))
},
sampling_params=sampling_params[i], sampling_params=sampling_params[i],
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
......
from array import array
from itertools import count from itertools import count
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
...@@ -9,7 +10,8 @@ import torch ...@@ -9,7 +10,8 @@ import torch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceData, SequenceGroupMetadata, SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceOutput) SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
...@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts( ...@@ -138,8 +140,9 @@ def create_seq_group_metadata_from_prompts(
seq_data={ seq_data={
i: i:
SequenceData( SequenceData(
prompt_token_ids=prompt_token_ids[:], array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]),
output_token_ids=cont_token_ids[:], _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
cont_token_ids[:]),
), ),
}, },
sampling_params=SamplingParams(temperature=0.0, ), sampling_params=SamplingParams(temperature=0.0, ),
......
import random import random
from array import array
from typing import Tuple from typing import Tuple
from unittest.mock import patch from unittest.mock import patch
...@@ -8,7 +9,8 @@ import torch ...@@ -8,7 +9,8 @@ import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str): ...@@ -69,7 +71,9 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata( SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])}, seq_data={
0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
},
sampling_params=SamplingParams(temperature=0, sampling_params=SamplingParams(temperature=0,
logits_processors=[pick_ith]), logits_processors=[pick_ith]),
block_tables={0: [1]}, block_tables={0: [1]},
......
from array import array
import pytest import pytest
from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE,
CompletionSequenceGroupOutput, SamplerOutput,
SequenceData, SequenceOutput) SequenceData, SequenceOutput)
from .core.utils import create_dummy_prompt from .core.utils import create_dummy_prompt
...@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs): ...@@ -54,7 +57,7 @@ def test_sampler_output_eq(sample_outputs):
def test_sequence_data_prefill(): def test_sequence_data_prefill():
seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4]) seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_uncomputed_tokens() == 4
assert seq_data.get_num_computed_tokens() == 0 assert seq_data.get_num_computed_tokens() == 0
# advance by 2 # advance by 2
......
from array import array
from typing import List from typing import List
import pytest import pytest
import torch import torch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import is_cpu from vllm.utils import is_cpu
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
...@@ -125,10 +127,12 @@ def test_prepare_prompt( ...@@ -125,10 +127,12 @@ def test_prepare_prompt(
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len))) seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len) encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len))) encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -319,10 +323,12 @@ def test_prepare_decode( ...@@ -319,10 +323,12 @@ def test_prepare_decode(
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len))) seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
encoder_seq_lens.append(encoder_seq_len) encoder_seq_lens.append(encoder_seq_len)
encoder_seq_data = SequenceData(list(range(encoder_seq_len))) encoder_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=False, is_prompt=False,
......
from array import array
from typing import List from typing import List
import pytest import pytest
...@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, ...@@ -7,7 +8,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import get_open_port from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
...@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size): ...@@ -46,7 +48,8 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len))) seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -163,7 +166,8 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1 context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len) context_lens.append(context_len)
seq_data = SequenceData(list(range(context_len))) seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)))
seq_data.update_num_computed_tokens(context_len) seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished. # Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0) seq_data.append_token_id(1, 0)
...@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -324,7 +328,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len) seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len))) seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
range(seq_len)))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
...@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): ...@@ -340,7 +345,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size): for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1 context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len)) prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))
seq_data = SequenceData(prompt_toks) seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0) seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len) seq_data.update_num_computed_tokens(context_len)
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest(ABC): class AdapterRequest(ABC):
""" """
Base class for adapter requests. Base class for adapter requests.
......
...@@ -770,8 +770,8 @@ class ParallelConfig: ...@@ -770,8 +770,8 @@ class ParallelConfig:
self.tokenizer_pool_config = tokenizer_pool_config self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group self.placement_group = placement_group
self.world_size = pipeline_parallel_size * self.tensor_parallel_size self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if worker_use_ray: if worker_use_ray:
if self.distributed_executor_backend is None: if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray" self.distributed_executor_backend = "ray"
...@@ -867,6 +867,11 @@ class SchedulerConfig: ...@@ -867,6 +867,11 @@ class SchedulerConfig:
swapping. However, when the sequence group has multiple sequences swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In (e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead. such a case, we use swapping instead.
send_delta_data: Private API. If used, scheduler sends delta data to
workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1
""" """
def __init__(self, def __init__(self,
...@@ -879,7 +884,8 @@ class SchedulerConfig: ...@@ -879,7 +884,8 @@ class SchedulerConfig:
enable_chunked_prefill: bool = False, enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False, embedding_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None, preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1) -> None: num_scheduler_steps: int = 1,
send_delta_data: bool = False) -> None:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
else: else:
...@@ -909,6 +915,7 @@ class SchedulerConfig: ...@@ -909,6 +915,7 @@ class SchedulerConfig:
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps self.num_scheduler_steps = num_scheduler_steps
self.send_delta_data = send_delta_data
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
......
...@@ -12,7 +12,8 @@ from vllm.logger import init_logger ...@@ -12,7 +12,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import PyObjectCache from vllm.utils import PyObjectCache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -363,8 +364,6 @@ class Scheduler: ...@@ -363,8 +364,6 @@ class Scheduler:
self.num_cumulative_preemption: int = 0 self.num_cumulative_preemption: int = 0
# Used to cache python objects # Used to cache python objects
self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
seq_group_metadata_builder)
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder) scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
...@@ -1048,15 +1047,10 @@ class Scheduler: ...@@ -1048,15 +1047,10 @@ class Scheduler:
token_chunk_size = scheduled_seq_group.token_chunk_size token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now) seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache.get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData # seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data seq_data: Dict[int, SequenceData] = {}
# seq_id -> physical block numbers # seq_id -> physical block numbers
block_tables: Dict[int, block_tables: Dict[int, List[int]] = {}
List[int]] = seq_group_metadata.block_tables
if seq_group.is_encoder_decoder(): if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup # Encoder associated with SequenceGroup
...@@ -1081,45 +1075,65 @@ class Scheduler: ...@@ -1081,45 +1075,65 @@ class Scheduler:
seq_group.get_seqs(status=SequenceStatus.RUNNING))) seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True do_sample = True
if seq_group.is_prefill(): is_prompt = seq_group.is_prefill()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill = False
if is_prompt:
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
# Prefill has only 1 sequence. # Prefill has only 1 sequence.
assert len(seqs) == 1 assert len(seqs) == 1
num_computed_tokens = seqs[0].data.get_num_computed_tokens()
is_first_prefill = num_computed_tokens == 0
# In the next iteration, all prompt tokens are not computed. # In the next iteration, all prompt tokens are not computed.
# It means the prefill is chunked, and we don't need sampling. # It means the prefill is chunked, and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when # NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated # a sequence is preempted, prefill includes previous generated
# output tokens. # output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < if (token_chunk_size + num_computed_tokens <
seqs[0].data.get_len()): seqs[0].data.get_len()):
do_sample = False do_sample = False
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
# prefill < decoding. # prefill < decoding.
is_prompt = seq_group.is_prefill() if is_first_prefill or not self.scheduler_config.send_delta_data:
seq_group_metadata = SequenceGroupMetadata(
seq_group_metadata.__init__( request_id=seq_group.request_id,
request_id=seq_group.request_id, is_prompt=is_prompt,
is_prompt=is_prompt, seq_data=seq_data,
seq_data=seq_data, sampling_params=seq_group.sampling_params,
sampling_params=seq_group.sampling_params, block_tables=block_tables,
block_tables=block_tables, do_sample=do_sample,
do_sample=do_sample, pooling_params=seq_group.pooling_params,
pooling_params=seq_group.pooling_params, token_chunk_size=token_chunk_size,
token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request,
lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums,
computed_block_nums=common_computed_block_nums, encoder_seq_data=encoder_seq_data,
encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table,
cross_block_table=cross_block_table, state=seq_group.state,
state=seq_group.state, # `multi_modal_data` will only be present for the 1st comm
# `multi_modal_data` will only be present for the 1st comm # between engine and worker.
# between engine and worker. # the subsequent comms can still use delta, but
# the subsequent comms can still use delta, but # `multi_modal_data` will be None.
# `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data
multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None,
if scheduler_outputs.num_prefill_groups > 0 else None, prompt_adapter_request=seq_group.prompt_adapter_request,
prompt_adapter_request=seq_group.prompt_adapter_request, )
) else:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
seq_data_delta = {}
for id, data in seq_data.items():
seq_data_delta[id] = data.get_delta_and_reset()
seq_group_metadata = SequenceGroupMetadataDelta(
seq_data_delta,
seq_group.request_id,
block_tables,
is_prompt,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
computed_block_nums=common_computed_block_nums,
)
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
# Now that the batch has been created, we can assume all blocks in the # Now that the batch has been created, we can assume all blocks in the
...@@ -1130,8 +1144,6 @@ class Scheduler: ...@@ -1130,8 +1144,6 @@ class Scheduler:
self.block_manager.mark_blocks_as_computed( self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group) scheduled_seq_group.seq_group)
self._seq_group_metadata_cache.reset()
scheduler_time = time.perf_counter() - scheduler_start_time scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently # Add this to scheduler time to all the sequences that are currently
# running. This will help estimate if the scheduler is a significant # running. This will help estimate if the scheduler is a significant
......
...@@ -5,6 +5,7 @@ from dataclasses import dataclass ...@@ -5,6 +5,7 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union) Union)
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, ObservabilityConfig, ParallelConfig,
...@@ -905,6 +906,8 @@ class EngineArgs: ...@@ -905,6 +906,8 @@ class EngineArgs:
embedding_mode=model_config.embedding_mode, embedding_mode=model_config.embedding_mode,
preemption_mode=self.preemption_mode, preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps, num_scheduler_steps=self.num_scheduler_steps,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
) )
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
......
...@@ -224,7 +224,6 @@ class LLMEngine: ...@@ -224,7 +224,6 @@ class LLMEngine:
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
load_general_plugins() load_general_plugins()
......
from array import array
from typing import Any, Type
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
def encode_hook(obj: Any) -> Any:
"""Custom msgspec enc hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if isinstance(obj, array):
assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
f"Given array has a type code of {obj.typecode}.")
return obj.tobytes()
def decode_hook(type: Type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
if type is array:
deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
deserialized.frombytes(obj)
return deserialized
...@@ -4,9 +4,12 @@ from collections import defaultdict ...@@ -4,9 +4,12 @@ from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import msgspec
import vllm.envs as envs import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync) DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
...@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
def shutdown(self) -> None: def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None: if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown() self.forward_dag.teardown()
...@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_remote_kwargs) ray_remote_kwargs)
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args() worker_wrapper_kwargs = self._get_worker_wrapper_args()
...@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
if self.forward_dag is None: if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req)) serialized_data = self.input_encoder.encode(execute_model_req)
return outputs[0] outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers( def _run_workers(
self, self,
...@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
if self.forward_dag is None: if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req) serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future outputs = await dag_future
return outputs[0] return self.output_decoder.decode(outputs[0])
async def _driver_execute_model_async( async def _driver_execute_model_async(
self, self,
......
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import msgspec
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
...@@ -24,6 +27,10 @@ try: ...@@ -24,6 +27,10 @@ try:
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=decode_hook)
self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
...@@ -33,16 +40,26 @@ try: ...@@ -33,16 +40,26 @@ try:
return node_id, gpu_ids return node_id, gpu_ids
def execute_model_spmd( def execute_model_spmd(
self, req_or_tuple: Union[ExecuteModelRequest, self, req_or_tuple: Union[bytes,
Tuple[ExecuteModelRequest, Tuple[bytes,
IntermediateTensors]]): Optional[IntermediateTensors]]]
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and """Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled. compiled DAG are both enabled.
Args: Args:
req_or_tuple: The request to execute the model, or a tuple req_or_tuple: A request or a tuple containing the
containing the request and intermediate tensors. request and intermediate tensors. Intermediate tensors are
None unless if it is provided because it is > 0 pipeline
stage. The request is serialized by msgspec.
""" """
if isinstance(req_or_tuple, bytes):
serialized_req, intermediate_tensors = req_or_tuple, None
else:
serialized_req, intermediate_tensors = req_or_tuple
execute_model_req = self.input_decoder.decode(serialized_req)
# TODO(swang): This is needed right now because Ray aDAG executes # TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current # on a background thread, so we need to reset torch's current
# device. # device.
...@@ -51,16 +68,14 @@ try: ...@@ -51,16 +68,14 @@ try:
torch.cuda.set_device(self.worker.device) torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
if isinstance(req_or_tuple, tuple):
execute_model_req, intermediate_tensors = req_or_tuple
else:
execute_model_req = req_or_tuple
intermediate_tensors = None
output = self.worker._execute_model_spmd(execute_model_req, output = self.worker._execute_model_spmd(execute_model_req,
intermediate_tensors) intermediate_tensors)
# Pipeline model request and output to the next pipeline stage.
if isinstance(output, IntermediateTensors): if isinstance(output, IntermediateTensors):
return execute_model_req, output output = serialized_req, output
else:
output = self.output_encoder.encode(output)
return output return output
ray_import_err = None ray_import_err = None
......
import functools import functools
from array import array
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol, from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
...@@ -21,6 +22,10 @@ logger = init_logger(__name__) ...@@ -21,6 +22,10 @@ logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@dataclass(frozen=True) @dataclass(frozen=True)
class InputContext: class InputContext:
...@@ -118,7 +123,8 @@ class InputRegistry: ...@@ -118,7 +123,8 @@ class InputRegistry:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
dummy_seq_data = SequenceData([0] * seq_len) dummy_seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
dummy_multi_modal_data = None dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data return dummy_seq_data, dummy_multi_modal_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment