# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                    MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                         init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

# Token id passed as `eos_token_id` to every Request built by
# create_requests() in this module.
EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
    '''Build the scheduler used by the tests.

    Args:
      model: model name handed to ModelConfig
      max_num_seqs: forwarded to SchedulerConfig
      max_num_batched_tokens: forwarded to SchedulerConfig; also used
                              as max_model_len when that is None
      enable_prefix_caching: forwarded to CacheConfig when True/False;
                             omitted (CacheConfig default) when None
      long_prefill_token_threshold: forwarded to SchedulerConfig
      disable_chunked_mm_input: forwarded to SchedulerConfig
      use_kv_connector: when True, attach a KVTransferConfig that uses
                        "SharedStorageConnector" in the "kv_both" role
      num_blocks: block count for KVCacheConfig and
                  cache_config.num_gpu_blocks
      block_size: block size for CacheConfig and the attention spec
      max_model_len: forwarded to SchedulerConfig
      num_speculative_tokens: when not None, build an "ngram"
                              SpeculativeConfig with this many tokens
      skip_tokenizer_init: forwarded to ModelConfig
      async_scheduling: forwarded to SchedulerConfig; also selects
                        AsyncScheduler over Scheduler

    Returns:
      {class}`Scheduler` instance, or {class}`AsyncScheduler` when
      async_scheduling is True
    '''
    # Without an explicit model length, reuse the batched-token budget.
    effective_max_model_len = (max_num_batched_tokens
                               if max_model_len is None else max_model_len)

    sched_cfg = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=effective_max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_cfg = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )

    # Only mention prefix caching to CacheConfig when the caller picked a
    # value; otherwise CacheConfig's own default is left in effect.
    apc_overrides: dict[str, bool] = {}
    if enable_prefix_caching is not None:
        apc_overrides['enable_prefix_caching'] = enable_prefix_caching
    cache_cfg = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **apc_overrides,
    )

    kv_transfer_cfg: Optional[KVTransferConfig] = None
    if use_kv_connector:
        kv_transfer_cfg = KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )

    spec_cfg: Optional[SpeculativeConfig] = (SpeculativeConfig(
        model="ngram", num_speculative_tokens=num_speculative_tokens)
                                             if num_speculative_tokens
                                             is not None else None)

    vllm_cfg = VllmConfig(
        scheduler_config=sched_cfg,
        model_config=model_cfg,
        cache_config=cache_cfg,
        kv_transfer_config=kv_transfer_cfg,
        speculative_config=spec_cfg,
    )

    # A single KV cache group with one layer named 'layer'.
    layer_group = KVCacheGroupSpec(
        ['layer'], FullAttentionSpec(block_size, 1, 1, torch.float32, False))
    kv_cache_cfg = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[layer_group],
    )
    # Set after VllmConfig has been constructed, as in the original flow.
    cache_cfg.num_gpu_blocks = num_blocks

    if async_scheduling:
        scheduler_type = AsyncScheduler
    else:
        scheduler_type = Scheduler
    return scheduler_type(
        vllm_config=vllm_cfg,
        kv_cache_config=kv_cache_cfg,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_cfg),
    )


# Set to True by create_requests() after it has called init_none_hash(),
# so that later calls skip the initialization.
_none_hash_initialized = False


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
    prompt_logprobs: Optional[int] = None,
    same_prompt: bool = False,
    block_size: int = 16,
) -> list[Request]:
    '''Create a batch of test requests.

    Request ``i`` gets request id ``"{i}"`` and a prompt made of
    ``num_tokens`` copies of a single token id: ``0`` for every request
    when ``same_prompt`` is True, otherwise ``i``. All requests share one
    SamplingParams object and one block hasher.

    Args:
      num_requests: number of requests to create
      num_tokens: prompt length of each request
      mm_positions: optional per-request placeholder ranges;
                    mm_positions[i] holds the ranges for request i,
                    each turned into a dummy "image" feature
      max_tokens: forwarded to SamplingParams
      stop_token_ids: forwarded to SamplingParams
      prompt_logprobs: forwarded to SamplingParams
      same_prompt: use token id 0 for every prompt instead of the
                   request index
      block_size: block size given to get_request_block_hasher

    Returns:
      list of {class}`Request` objects, in index order
    '''
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(sha256)
        _none_hash_initialized = True

    block_hasher = get_request_block_hasher(block_size, sha256)
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        mm_features = []
        if mm_positions is not None:
            for j, position in enumerate(mm_positions[i]):
                # Dummy hash for each mm item should be unique
                # since encoder cache tracks entries by hash
                identifier = f"hash{i}_{j}"
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=MultiModalKwargsItem.dummy("dummy_m"),
                        mm_position=position,
                        identifier=identifier,
                        modality="image"))

        prompt_token_ids = [0 if same_prompt else i] * num_tokens
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            # An empty feature list is passed as None.
            mm_features=mm_features or None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=block_hasher,
        )
        requests.append(request)
    return requests