Commit f48954a4 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.5.0

parents 1dba29d3 8f89d720
...@@ -55,7 +55,7 @@ class SpeculativeScores: ...@@ -55,7 +55,7 @@ class SpeculativeScores:
class SpeculativeProposer(ABC): class SpeculativeProposer(ABC):
@abstractmethod @abstractmethod
def get_proposals( def get_spec_proposals(
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
......
...@@ -7,11 +7,12 @@ import torch ...@@ -7,11 +7,12 @@ import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
class MultiStepWorker(Worker): class MultiStepWorker(Worker, ProposerWorkerBase):
"""The MultiStepWorker is equivalent to a Worker except that it allows """The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead allocated enough space to store the additional KV. This reduces overhead
...@@ -33,7 +34,7 @@ class MultiStepWorker(Worker): ...@@ -33,7 +34,7 @@ class MultiStepWorker(Worker):
super().init_device() super().init_device()
self._proposer = Top1Proposer( self._proposer = Top1Proposer(
weakref.proxy(self), weakref.proxy(self), # type: ignore[arg-type]
self.device, self.device,
self.vocab_size, self.vocab_size,
max_proposal_len=self.max_model_len, max_proposal_len=self.max_model_len,
...@@ -92,11 +93,12 @@ class MultiStepWorker(Worker): ...@@ -92,11 +93,12 @@ class MultiStepWorker(Worker):
speculative tokens per sequence is determined by max_proposal_len. speculative tokens per sequence is determined by max_proposal_len.
""" """
return self._proposer.get_proposals(execute_model_req) return self._proposer.get_spec_proposals(execute_model_req)
@staticmethod
def _append_new_tokens( def _append_new_tokens(
self, model_output: SamplerOutput, model_output: List[SamplerOutput],
seq_group_metadata_list: SequenceGroupMetadata) -> None: seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Given model output from a single run, append the tokens to the """Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes. required if the worker is to perform multiple forward passes.
...@@ -116,8 +118,9 @@ class MultiStepWorker(Worker): ...@@ -116,8 +118,9 @@ class MultiStepWorker(Worker):
seq.append_token_id(token_id, token_logprob.logprob) seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1) seq.update_num_computed_tokens(1)
@staticmethod
def _shallow_copy_inputs( def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata] seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[SequenceGroupMetadata]: ) -> List[SequenceGroupMetadata]:
"""Copy input data structures to remove side-effects when input data """Copy input data structures to remove side-effects when input data
structures are shared with other modules. structures are shared with other modules.
......
...@@ -5,15 +5,16 @@ import torch ...@@ -5,15 +5,16 @@ import torch
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(LoraNotSupportedWorkerBase): class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
"""NGramWorker provides a light drafter without need for model. """NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding, Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scenerios and in future we may also do RAG type drafter and other scenarios
which don't rely on LLM model to give proposals. which don't rely on LLM model to give proposals.
""" """
...@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase): ...@@ -38,34 +39,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# Current only support Top1Proposer # Current only support Top1Proposer
self._proposer = Top1Proposer( self._proposer = Top1Proposer(
weakref.proxy(self), weakref.proxy(self), # type: ignore[arg-type]
device=self.device, device=self.device,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
) )
def set_include_gpu_probs_tensor(self):
# NGram don't need gpu sampler
pass
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def determine_num_available_blocks(self) -> None:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""As there is no cache need to handle, just pass this function"""
pass
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes."""
return 0
def sampler_output( def sampler_output(
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
...@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase): ...@@ -97,7 +75,6 @@ class NGramWorker(LoraNotSupportedWorkerBase):
-1, -1,
): ):
ngram_tensor = input_ids[-ngram_size:] ngram_tensor = input_ids[-ngram_size:]
proposal_start_idx = None
if ngram_size == 1: if ngram_size == 1:
# Do not match itself and do not use unfold and all # Do not match itself and do not use unfold and all
matches = (input_ids[:-1] == ngram_tensor) matches = (input_ids[:-1] == ngram_tensor)
...@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase): ...@@ -161,7 +138,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
speculative tokens per sequence is determined by max_proposal_len. speculative tokens per sequence is determined by max_proposal_len.
""" """
return self._proposer.get_proposals(execute_model_req) return self._proposer.get_spec_proposals(execute_model_req)
def _raise_if_unsupported( def _raise_if_unsupported(
self, self,
......
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposer
from vllm.worker.worker_base import WorkerBase
class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
    """Base interface for workers that act as speculative proposers.

    Combines the ``WorkerBase`` and ``SpeculativeProposer`` interfaces and
    adds the ``sampler_output`` hook every concrete proposer must provide.
    """

    @abstractmethod
    def sampler_output(
            self, execute_model_req: ExecuteModelRequest,
            sample_len: int) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Produce sampler outputs for ``execute_model_req``.

        Abstract: this base version only raises ``NotImplementedError``.

        Returns:
            A pair of (optional list of ``SamplerOutput``, bool flag).
            NOTE(review): the flag is presumably the "transposed" marker
            consumed by the proposer, and ``sample_len`` presumably the
            number of tokens to sample — confirm against implementations.
        """
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self):
        """Optional hook; the default implementation does nothing."""
        return None
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Base for proposer workers that do not use a model with a KV cache.

    Provides trivial implementations of the model-execution and cache
    related worker methods so that subclasses need not define them.
    """

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Return an empty list, ignoring ``execute_model_req``.

        Proposals are obtained through ``get_spec_proposals`` instead.
        """
        return list()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Always raise ``NotImplementedError``.

        This is only ever invoked on the target model, never on the proposer.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Do nothing; both block counts are ignored."""
        return None

    def get_cache_block_size_bytes(self) -> int:
        """Report a cache block size of zero bytes."""
        return 0
...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from vllm.config import SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
...@@ -14,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, ...@@ -14,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.util import (create_sequence_group_output, from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids, get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
...@@ -29,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": ...@@ -29,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
""" """
assert "speculative_config" in kwargs assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config") speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None assert speculative_config is not None
target_worker = Worker(*args, **kwargs) target_worker = Worker(*args, **kwargs)
...@@ -108,16 +110,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -108,16 +110,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("Configuring SpecDecodeWorker with proposer=%s", logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker)) type(proposer_worker))
return SpecDecodeWorker( return SpecDecodeWorker(proposer_worker,
proposer_worker, scorer_worker,
scorer_worker, disable_by_batch_size=disable_by_batch_size,
disable_by_batch_size=disable_by_batch_size, rejection_sampler=RejectionSampler(
rejection_sampler=RejectionSampler( disable_bonus_tokens=disable_bonus_tokens))
disable_bonus_tokens=disable_bonus_tokens, ))
def __init__( def __init__(
self, self,
proposer_worker: WorkerBase, proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler, rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
...@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -260,7 +261,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# This is required as if the number of draft model runs changes # This is required as if the number of draft model runs changes
# dynamically, the non-driver workers won't know unless we perform a # dynamically, the non-driver workers won't know unless we perform a
# communication to inform then. # communication to inform them.
broadcast_dict = dict( broadcast_dict = dict(
num_lookahead_slots=num_lookahead_slots, num_lookahead_slots=num_lookahead_slots,
disable_all_speculation=disable_all_speculation, disable_all_speculation=disable_all_speculation,
......
...@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput, ...@@ -6,8 +6,8 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer) SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.util import sampler_output_to_torch from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker_base import WorkerBase
class Top1Proposer(SpeculativeProposer): class Top1Proposer(SpeculativeProposer):
...@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -29,7 +29,7 @@ class Top1Proposer(SpeculativeProposer):
def __init__( def __init__(
self, self,
worker: WorkerBase, worker: ProposerWorkerBase,
device: str, device: str,
vocab_size: int, vocab_size: int,
max_proposal_len: Optional[int] = None, max_proposal_len: Optional[int] = None,
...@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -39,7 +39,7 @@ class Top1Proposer(SpeculativeProposer):
self.max_proposal_len = max_proposal_len self.max_proposal_len = max_proposal_len
self._vocab_size = vocab_size self._vocab_size = vocab_size
def get_proposals( def get_spec_proposals(
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
...@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer): ...@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_indices, nonzero_proposal_len_indices,
) )
def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, @staticmethod
def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
nonzero_proposal_len_indices, transposed): nonzero_proposal_len_indices, transposed):
"""Remove sequences from nonzero_proposal_len_indices and reset """Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal their proposal_len to 0 the draft worker does not provide a proposal
...@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer): ...@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
self, self,
batch_size: int, batch_size: int,
proposal_len: int, proposal_len: int,
maybe_sampler_output: Optional[SamplerOutput], maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int], proposal_lens: List[int],
nonzero_proposal_len_indices: List[int], nonzero_proposal_len_indices: List[int],
sampler_transposed: bool, sampler_transposed: bool,
...@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer): ...@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
if maybe_sampler_output is None: if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None. # If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals. # In this case we return empty proposals.
proposal_tokens = torch.full( proposal_tokens = torch.tensor(-1,
size=( dtype=torch.long,
batch_size, device=self._device).expand(
proposal_len, batch_size, proposal_len)
), proposal_probs = torch.tensor(0,
fill_value=-1, dtype=torch.float32,
dtype=torch.long, device=self._device).expand(
device=self._device, batch_size, proposal_len,
) self._vocab_size)
proposal_probs = torch.zeros( proposal_lens_tensor = torch.tensor(0,
batch_size, dtype=torch.long,
proposal_len, device=self._device).expand(
self._vocab_size, len(proposal_lens))
dtype=torch.float32,
device=self._device,
)
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output sampler_output = maybe_sampler_output
...@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer): ...@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
# Now, reformat the output GPU tensors such that each sequence has # Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1] # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = torch.full( entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]), size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1, fill_value=-1,
dtype=torch.long,
device=self._device,
) )
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = torch.zeros( entire_proposal_probs = proposal_probs.new_zeros(
batch_size, batch_size,
*proposal_probs.shape[1:], *proposal_probs.shape[1:],
dtype=torch.float32,
device=self._device,
) )
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
......
from contextlib import contextmanager from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput) SequenceOutput)
SeqId = int SeqId = int
...@@ -16,11 +15,7 @@ def get_all_seq_ids( ...@@ -16,11 +15,7 @@ def get_all_seq_ids(
"""Given a list of SequenceGroupMetadata, create a list of all """Given a list of SequenceGroupMetadata, create a list of all
sequence ids. sequence ids.
""" """
return list( return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data]
chain.from_iterable([
seq_group_metadata.seq_data.keys()
for seq_group_metadata in seq_group_metadata_list
]))
def get_all_num_logprobs( def get_all_num_logprobs(
...@@ -68,7 +63,7 @@ def create_sequence_group_output( ...@@ -68,7 +63,7 @@ def create_sequence_group_output(
seq_id: SeqId, seq_id: SeqId,
topk_token_ids: List[int], topk_token_ids: List[int],
topk_logprobs: List[float], topk_logprobs: List[float],
) -> SequenceGroupOutput: ) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results. """Create a SequenceGroupOutput given the sampling results.
Args: Args:
......
from typing import Dict, Optional from typing import Dict, Optional, Type
from transformers import AutoConfig, PretrainedConfig from transformers import PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MPTConfig, RWConfig) JAISConfig, MPTConfig, RWConfig)
logger = init_logger(__name__) logger = init_logger(__name__)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,
...@@ -22,8 +23,13 @@ def get_config(model: str, ...@@ -22,8 +23,13 @@ def get_config(model: str,
trust_remote_code: bool, trust_remote_code: bool,
revision: Optional[str] = None, revision: Optional[str] = None,
code_revision: Optional[str] = None, code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None) -> PretrainedConfig: rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None) -> PretrainedConfig:
try: try:
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
else:
from transformers import AutoConfig
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, model,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -45,10 +51,12 @@ def get_config(model: str, ...@@ -45,10 +51,12 @@ def get_config(model: str,
config = config_class.from_pretrained(model, config = config_class.from_pretrained(model,
revision=revision, revision=revision,
code_revision=code_revision) code_revision=code_revision)
if rope_scaling is not None: for key, value in [("rope_scaling", rope_scaling),
logger.info("Updating rope_scaling from %r to %r", ("rope_theta", rope_theta)]:
getattr(config, "rope_scaling", None), rope_scaling) if value is not None:
config.update({"rope_scaling": rope_scaling}) logger.info("Updating %s from %r to %r", key,
getattr(config, key, None), value)
config.update({key: value})
return config return config
...@@ -63,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig): ...@@ -63,4 +71,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert hasattr(config.text_config, "num_attention_heads") assert hasattr(config.text_config, "num_attention_heads")
return config.text_config return config.text_config
else: else:
return config return config
\ No newline at end of file
from functools import lru_cache
from typing import Optional
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_image_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    **kwargs,
) -> BaseImageProcessor:
    """Gets an image processor for the given model name via HuggingFace.

    Args:
        processor_name: First positional argument handed to
            ``AutoImageProcessor.from_pretrained``.
        *args: Extra positional arguments forwarded unchanged.
        trust_remote_code: Forwarded to ``from_pretrained``; also decides
            how a ``ValueError`` from the loader is reported (see Raises).
        revision: Forwarded to ``from_pretrained``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The object returned by ``AutoImageProcessor.from_pretrained``.

    Raises:
        RuntimeError: If loading raised ``ValueError`` while
            ``trust_remote_code`` was false. The original error is chained
            as the cause.
        ValueError: Re-raised untouched when ``trust_remote_code`` was
            already true.
    """
    try:
        processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs)
    except ValueError as e:
        if trust_remote_code:
            # Remote code was already allowed, so the hint below would not
            # help. A bare ``raise`` propagates the active exception as-is.
            raise

        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        err_msg = (
            "Failed to load the image processor. If the image processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI.")
        raise RuntimeError(err_msg) from e

    return processor
# Memoized variant of get_image_processor, built with functools.lru_cache
# using its default cache size. lru_cache keys the cache on the call
# arguments, so every argument passed to this wrapper must be hashable.
cached_get_image_processor = lru_cache(get_image_processor)
...@@ -17,10 +17,12 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, ...@@ -17,10 +17,12 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar, Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
Union) Union)
import numpy as np
import psutil import psutil
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T") T = TypeVar("T")
...@@ -147,12 +149,8 @@ def is_neuron() -> bool: ...@@ -147,12 +149,8 @@ def is_neuron() -> bool:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
max_shared_mem = ( max_shared_mem = (
cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)) ops.get_max_shared_memory_per_block_device_attribute(gpu))
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail # will fail
assert max_shared_mem > 0, "max_shared_mem can not be zero" assert max_shared_mem > 0, "max_shared_mem can not be zero"
...@@ -288,7 +286,15 @@ def get_distributed_init_method(ip: str, port: int) -> str: ...@@ -288,7 +286,15 @@ def get_distributed_init_method(ip: str, port: int) -> str:
def get_open_port() -> int: def get_open_port() -> int:
port = envs.VLLM_PORT port = envs.VLLM_PORT
if port is not None: if port is not None:
return port while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
return port
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d",
port - 1, port)
# try ipv4 # try ipv4
try: try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
...@@ -501,11 +507,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: ...@@ -501,11 +507,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f"(e.g., 1, 2, 3). Given input: {s}") from e f"(e.g., 1, 2, 3). Given input: {s}") from e
def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def make_tensor_with_pad( def make_tensor_with_pad(
x: List[List[int]], x: List[List[int]],
max_len: int, max_len: int,
...@@ -518,7 +519,10 @@ def make_tensor_with_pad( ...@@ -518,7 +519,10 @@ def make_tensor_with_pad(
The padding is applied to the end of each inner list until it reaches The padding is applied to the end of each inner list until it reaches
`max_len`. `max_len`.
""" """
padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x] padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len
padded_x[ind, :len(blocktb)] = blocktb
return torch.tensor(padded_x, dtype=dtype, device=device) return torch.tensor(padded_x, dtype=dtype, device=device)
......
from typing import List, Optional, Tuple from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict ...@@ -11,6 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
...@@ -63,6 +65,16 @@ class CPUModelRunner: ...@@ -63,6 +65,16 @@ class CPUModelRunner:
self.block_size, self.block_size,
) )
# Create processor for multi-modal data
if self.vision_language_config is not None:
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
...@@ -80,14 +92,15 @@ class CPUModelRunner: ...@@ -80,14 +92,15 @@ class CPUModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
Optional[torch.Tensor]]: str, torch.Tensor]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
...@@ -108,9 +121,17 @@ class CPUModelRunner: ...@@ -108,9 +121,17 @@ class CPUModelRunner:
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len))) input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data: mm_data = seq_group_metadata.multi_modal_data
multi_modal_input_list.append( if mm_data is not None:
seq_group_metadata.multi_modal_data.data) # Process multi-modal data
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
# Compute the slot mapping. # Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
...@@ -134,14 +155,10 @@ class CPUModelRunner: ...@@ -134,14 +155,10 @@ class CPUModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping.append(slot)
if multi_modal_input_list: multi_modal_kwargs = {
assert self.vision_language_config, ( k: torch.cat(v, dim=0).to(self.device)
"Multi-modal inputs are only supported by " for k, v in multi_modal_kwargs_list.items()
"vision language models.") }
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
num_prompt_tokens = len(input_tokens) num_prompt_tokens = len(input_tokens)
...@@ -167,7 +184,7 @@ class CPUModelRunner: ...@@ -167,7 +184,7 @@ class CPUModelRunner:
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input) multi_modal_kwargs)
def _prepare_decode( def _prepare_decode(
self, self,
...@@ -257,8 +274,8 @@ class CPUModelRunner: ...@@ -257,8 +274,8 @@ class CPUModelRunner:
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[torch.Tensor]]: Optional[Dict[str, torch.Tensor]]]:
multi_modal_input = None multi_modal_kwargs = None
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
...@@ -266,7 +283,7 @@ class CPUModelRunner: ...@@ -266,7 +283,7 @@ class CPUModelRunner:
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens, (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list) ) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, (input_tokens, input_positions,
...@@ -307,7 +324,7 @@ class CPUModelRunner: ...@@ -307,7 +324,7 @@ class CPUModelRunner:
) )
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_input) sampling_metadata, multi_modal_kwargs)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
......
...@@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -90,7 +90,7 @@ class EmbeddingModelRunner(ModelRunner):
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]: Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
# Prepare input tensors. # Prepare input tensors.
...@@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -102,7 +102,7 @@ class EmbeddingModelRunner(ModelRunner):
_, _,
lora_mapping, lora_mapping,
lora_requests, lora_requests,
multi_modal_input, multi_modal_kwargs,
slot_mapping, slot_mapping,
num_prefill_tokens, num_prefill_tokens,
num_decode_tokens, num_decode_tokens,
...@@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -117,7 +117,7 @@ class EmbeddingModelRunner(ModelRunner):
"input_positions": input_positions, "input_positions": input_positions,
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input, "multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens, "num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens, "num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping, "slot_mapping": slot_mapping,
...@@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -132,7 +132,7 @@ class EmbeddingModelRunner(ModelRunner):
input_positions = metadata_dict.pop("input_positions") input_positions = metadata_dict.pop("input_positions")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input") multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict: if metadata_dict:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
**metadata_dict) **metadata_dict)
...@@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -143,7 +143,7 @@ class EmbeddingModelRunner(ModelRunner):
prompt_lens=None) prompt_lens=None)
return (input_tokens, input_positions, attn_metadata, pooling_metadata, return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input) lora_requests, lora_mapping, multi_modal_kwargs)
def _prepare_pooling( def _prepare_pooling(
self, self,
......
import gc
import time import time
import warnings import warnings
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np import numpy as np
...@@ -18,9 +20,9 @@ from vllm.lora.request import LoRARequest ...@@ -18,9 +20,9 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad) is_pin_memory_available, make_tensor_with_pad)
...@@ -34,6 +36,7 @@ _BATCH_SIZE_ALIGNMENT = 8 ...@@ -34,6 +36,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
] ]
_NUM_WARMUP_ITERS = 2
class ModelInput(NamedTuple): class ModelInput(NamedTuple):
...@@ -44,7 +47,7 @@ class ModelInput(NamedTuple): ...@@ -44,7 +47,7 @@ class ModelInput(NamedTuple):
query_lens: List[int] query_lens: List[int]
lora_mapping: Optional[LoRAMapping] lora_mapping: Optional[LoRAMapping]
lora_requests: Set[LoRARequest] lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor] multi_modal_kwargs: Dict[str, torch.Tensor]
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
num_prefill_tokens: int num_prefill_tokens: int
num_decode_tokens: int num_decode_tokens: int
...@@ -60,7 +63,7 @@ class ModelInput(NamedTuple): ...@@ -60,7 +63,7 @@ class ModelInput(NamedTuple):
query_lens=[], query_lens=[],
lora_mapping=None, lora_mapping=None,
lora_requests=set(), lora_requests=set(),
multi_modal_input=None, multi_modal_kwargs={},
slot_mapping=torch.empty(0, device=device), slot_mapping=torch.empty(0, device=device),
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=0, num_decode_tokens=0,
...@@ -122,6 +125,16 @@ class ModelRunner: ...@@ -122,6 +125,16 @@ class ModelRunner:
self.block_size, self.block_size,
) )
# Create processor for multi-modal data
if self.vision_language_config is not None:
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Lazy initialization # Lazy initialization
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
# Set if the backend is flashinfer. # Set if the backend is flashinfer.
...@@ -242,7 +255,8 @@ class ModelRunner: ...@@ -242,7 +255,8 @@ class ModelRunner:
context_lens: List[int] = [] context_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
decode_only = True decode_only = True
num_prefills = 0 num_prefills = 0
num_prefill_tokens = 0 num_prefill_tokens = 0
...@@ -415,11 +429,19 @@ class ModelRunner: ...@@ -415,11 +429,19 @@ class ModelRunner:
[lora_id] * [lora_id] *
(query_len if seq_group_metadata.sampling_params (query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs and seq_group_metadata.sampling_params.prompt_logprobs
else 1)) is not None else 1))
if seq_group_metadata.multi_modal_data: mm_data = seq_group_metadata.multi_modal_data
multi_modal_input_list.append( if mm_data is not None:
seq_group_metadata.multi_modal_data.data) # Process multi-modal data
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v)
if _is_block_tables_empty(seq_group_metadata.block_tables): if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not # During memory profiling, the block tables are not
...@@ -505,26 +527,6 @@ class ModelRunner: ...@@ -505,26 +527,6 @@ class ModelRunner:
) )
assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
...@@ -532,11 +534,6 @@ class ModelRunner: ...@@ -532,11 +534,6 @@ class ModelRunner:
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
torch.cumsum(seq_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
...@@ -589,6 +586,21 @@ class ModelRunner: ...@@ -589,6 +586,21 @@ class ModelRunner:
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc,
data_type=kv_cache_dtype) data_type=kv_cache_dtype)
else: else:
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
...@@ -614,6 +626,11 @@ class ModelRunner: ...@@ -614,6 +626,11 @@ class ModelRunner:
else: else:
lora_mapping = None lora_mapping = None
multi_modal_kwargs = {
k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items()
}
return ModelInput( return ModelInput(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
...@@ -622,7 +639,7 @@ class ModelRunner: ...@@ -622,7 +639,7 @@ class ModelRunner:
query_lens=query_lens, query_lens=query_lens,
lora_mapping=lora_mapping, lora_mapping=lora_mapping,
lora_requests=lora_requests, lora_requests=lora_requests,
multi_modal_input=multi_modal_input, multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
...@@ -633,7 +650,7 @@ class ModelRunner: ...@@ -633,7 +650,7 @@ class ModelRunner:
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]: Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
# Prepare input tensors. # Prepare input tensors.
...@@ -645,7 +662,7 @@ class ModelRunner: ...@@ -645,7 +662,7 @@ class ModelRunner:
query_lens, query_lens,
lora_mapping, lora_mapping,
lora_requests, lora_requests,
multi_modal_input, multi_modal_kwargs,
slot_mapping, slot_mapping,
num_prefill_tokens, num_prefill_tokens,
num_decode_tokens, num_decode_tokens,
...@@ -662,7 +679,7 @@ class ModelRunner: ...@@ -662,7 +679,7 @@ class ModelRunner:
sampling_metadata.selected_token_indices, sampling_metadata.selected_token_indices,
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input, "multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens, "num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens, "num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping, "slot_mapping": slot_mapping,
...@@ -679,7 +696,7 @@ class ModelRunner: ...@@ -679,7 +696,7 @@ class ModelRunner:
"selected_token_indices") "selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input") multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict: if metadata_dict:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
**metadata_dict) **metadata_dict)
...@@ -694,7 +711,7 @@ class ModelRunner: ...@@ -694,7 +711,7 @@ class ModelRunner:
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping, sampling_metadata, lora_requests, lora_mapping,
multi_modal_input) multi_modal_kwargs)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
...@@ -703,7 +720,7 @@ class ModelRunner: ...@@ -703,7 +720,7 @@ class ModelRunner:
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata, (input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_input lora_requests, lora_mapping, multi_modal_kwargs
) = self.prepare_input_tensors(seq_group_metadata_list) ) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config: if self.lora_config:
...@@ -717,15 +734,14 @@ class ModelRunner: ...@@ -717,15 +734,14 @@ class ModelRunner:
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens, hidden_states = model_executable(
"positions": input_positions, input_ids=input_tokens,
"kv_caches": kv_caches, positions=input_positions,
"attn_metadata": attn_metadata, kv_caches=kv_caches,
} attn_metadata=attn_metadata,
if self.vision_language_config: **multi_modal_kwargs,
execute_model_kwargs.update({"image_input": multi_modal_input}) )
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
...@@ -781,16 +797,24 @@ class ModelRunner: ...@@ -781,16 +797,24 @@ class ModelRunner:
# To exercise the worst scenario for GPU memory consumption, # To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number # the number of seqs (batch_size) is chosen to maximize the number
# of images processed. # of images processed.
if self.vision_language_config: model_config = self.model_config
vlm_config = self.vision_language_config
if vlm_config:
max_num_seqs = min( max_num_seqs = min(
max_num_seqs, max_num_seqs,
int(max_num_batched_tokens / int(max_num_batched_tokens / vlm_config.image_feature_size))
self.vision_language_config.image_feature_size))
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
seq_data, fake_multi_modal_input = _prepare_fake_inputs(
seq_len, self.vision_language_config) if vlm_config is None:
seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
else:
seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
.dummy_data_for_profiling(seq_len, model_config, vlm_config)
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
...@@ -799,7 +823,7 @@ class ModelRunner: ...@@ -799,7 +823,7 @@ class ModelRunner:
block_tables=None, block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id] lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None, if dummy_lora_requests_per_seq else None,
multi_modal_data=fake_multi_modal_input, multi_modal_data=dummy_multi_modal_data,
) )
seqs.append(seq) seqs.append(seq)
...@@ -871,6 +895,10 @@ class ModelRunner: ...@@ -871,6 +895,10 @@ class ModelRunner:
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_states: Optional[torch.Tensor] = None
graph_batch_size = _get_graph_batch_size( graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs) self.scheduler_config.max_num_seqs)
batch_size_capture_list = [ batch_size_capture_list = [
...@@ -907,9 +935,11 @@ class ModelRunner: ...@@ -907,9 +935,11 @@ class ModelRunner:
self.set_active_loras(set(), lora_mapping) self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model) graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture( hidden_states = graph_runner.capture(
input_tokens[:batch_size], input_tokens[:batch_size],
input_positions[:batch_size], input_positions[:batch_size],
hidden_states[:batch_size]
if hidden_states is not None else None,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
memory_pool=self.graph_memory_pool, memory_pool=self.graph_memory_pool,
...@@ -946,35 +976,46 @@ class CUDAGraphRunner: ...@@ -946,35 +976,46 @@ class CUDAGraphRunner:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]], memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream, stream: torch.cuda.Stream,
**kwargs, **kwargs,
) -> None: ) -> torch.Tensor:
assert self._graph is None assert self._graph is None
# Run the model once without capturing the graph. # Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune). # kernel launches for initial benchmarking (e.g., Triton autotune).
self.model( # Note one iteration is not enough for torch.jit.script
input_ids, for _ in range(_NUM_WARMUP_ITERS):
positions, self.model(
kv_caches, input_ids,
attn_metadata, positions,
**kwargs, kv_caches,
) attn_metadata,
**kwargs,
)
torch.cuda.synchronize() torch.cuda.synchronize()
# Capture the graph. # Capture the graph.
self._graph = torch.cuda.CUDAGraph() self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
hidden_states = self.model( output_hidden_states = self.model(
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
**kwargs, **kwargs,
) )
if hidden_states is not None:
hidden_states.copy_(output_hidden_states)
else:
hidden_states = output_hidden_states
del output_hidden_states
# make sure `output_hidden_states` is deleted
# in the graph's memory pool
gc.collect()
torch.cuda.synchronize() torch.cuda.synchronize()
# Save the input and output buffers. # Save the input and output buffers.
...@@ -987,7 +1028,7 @@ class CUDAGraphRunner: ...@@ -987,7 +1028,7 @@ class CUDAGraphRunner:
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
return return hidden_states
def forward( def forward(
self, self,
...@@ -1034,24 +1075,6 @@ def _get_graph_batch_size(batch_size: int) -> int: ...@@ -1034,24 +1075,6 @@ def _get_graph_batch_size(batch_size: int) -> int:
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Build placeholder inputs for the memory-profiling run.

    Args:
        seq_len: Requested number of prompt tokens for the dummy sequence.
        vision_language_config: Vision-language settings, or a falsy value
            for a text-only model.

    Returns:
        A ``(SequenceData, MultiModalData | None)`` pair. Without a
        vision-language config the prompt is ``seq_len`` zero tokens and the
        second element is ``None``. With one, the prompt starts with
        ``image_feature_size`` copies of ``image_token_id`` followed by zero
        padding, and the second element wraps an all-zero float16 tensor of
        shape ``image_input_shape``.
    """
    # Text-only model: nothing but zero tokens, no image payload.
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    image_tokens = [vision_language_config.image_token_id] * num_image_tokens
    # A non-positive remainder yields an empty padding list, so the prompt is
    # never shorter than the image placeholder run.
    padding = [0] * (seq_len - num_image_tokens)

    dummy_image = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(image_tokens + padding), dummy_image
def _is_block_tables_empty(block_tables: Union[None, Dict]): def _is_block_tables_empty(block_tables: Union[None, Dict]):
""" """
Check if block_tables is None or a dictionary with all None values. Check if block_tables is None or a dictionary with all None values.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment