Commit e00b0a19 authored by zhuwenwen

merge v0.3.3

parents ead94d93 3f1166ab
from typing import List, Optional
import time
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
SequenceStatus, RequestMetrics)
from vllm.lora.request import LoRARequest
class CompletionOutput: class CompletionOutput:
...@@ -16,6 +18,7 @@ class CompletionOutput: ...@@ -16,6 +18,7 @@ class CompletionOutput:
logprobs: The log probabilities of the top probability words at each logprobs: The log probabilities of the top probability words at each
position if the logprobs are requested. position if the logprobs are requested.
finish_reason: The reason why the sequence is finished. finish_reason: The reason why the sequence is finished.
lora_request: The LoRA request that was used to generate the output.
""" """
def __init__( def __init__(
...@@ -26,6 +29,7 @@ class CompletionOutput: ...@@ -26,6 +29,7 @@ class CompletionOutput:
cumulative_logprob: float, cumulative_logprob: float,
logprobs: Optional[SampleLogprobs], logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None, finish_reason: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.index = index self.index = index
self.text = text self.text = text
...@@ -33,6 +37,7 @@ class CompletionOutput: ...@@ -33,6 +37,7 @@ class CompletionOutput:
self.cumulative_logprob = cumulative_logprob self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs self.logprobs = logprobs
self.finish_reason = finish_reason self.finish_reason = finish_reason
self.lora_request = lora_request
def finished(self) -> bool: def finished(self) -> bool:
return self.finish_reason is not None return self.finish_reason is not None
...@@ -56,6 +61,8 @@ class RequestOutput: ...@@ -56,6 +61,8 @@ class RequestOutput:
prompt_logprobs: The log probabilities to return per prompt token. prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished. finished: Whether the whole request is finished.
metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output.
""" """
def __init__( def __init__(
...@@ -66,6 +73,8 @@ class RequestOutput: ...@@ -66,6 +73,8 @@ class RequestOutput:
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
finished: bool, finished: bool,
metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
...@@ -73,6 +82,8 @@ class RequestOutput: ...@@ -73,6 +82,8 @@ class RequestOutput:
self.prompt_logprobs = prompt_logprobs self.prompt_logprobs = prompt_logprobs
self.outputs = outputs self.outputs = outputs
self.finished = finished self.finished = finished
self.metrics = metrics
self.lora_request = lora_request
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
...@@ -108,8 +119,16 @@ class RequestOutput: ...@@ -108,8 +119,16 @@ class RequestOutput:
prompt_token_ids = seq_group.prompt_token_ids prompt_token_ids = seq_group.prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished() finished = seq_group.is_finished()
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
prompt,
prompt_token_ids,
prompt_logprobs,
outputs,
finished,
seq_group.metrics,
lora_request=seq_group.lora_request)
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
@@ -117,4 +136,6 @@ class RequestOutput:
f"prompt_token_ids={self.prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"lora_request={self.lora_request})")
from typing import Dict, List, Sequence, Tuple, Optional
from vllm.block import BlockTable
class Prefix:
"""Data and states associated with a prefix of prompt tokens for multiple
sequence groups.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
token_ids: The token ids of the prefix.
block_size: The block size of the executed model.
"""
def __init__(
self,
token_ids: Sequence[int],
block_size: int,
) -> None:
self.token_ids = tuple(token_ids)
self.block_size = block_size
self.length = len(token_ids)
self.hash = hash(token_ids)
assert self.length % block_size == 0
self.block_table: Optional[BlockTable] = None
self.computed = False
@property
def allocated(self) -> bool:
return self.block_table is not None
def get_num_blocks(self) -> int:
return self.length // self.block_size
def get_block_numbers(self) -> List[int]:
return [block.block_number for block in self.block_table]
def get_length(self) -> int:
return self.length
def __hash__(self) -> int:
return self.hash
def set_block_table(self, block_table: BlockTable) -> None:
self.block_table = block_table.copy()
class PrefixPool:
"""Manages all the prompt prefixes.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
block_size: The block size of the executed model.
Attributes:
prefixes: A list of all the prefixes.
block_size: The block size of the executed model.
"""
def __init__(
self,
block_size: int,
) -> None:
# TODO(zhuohan): Add a capacity limit to the prefix pool.
self.prefixes: Dict[int, Prefix] = {}
self.block_size = block_size
def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
new_length = len(token_ids) // self.block_size * self.block_size
return tuple(token_ids[:new_length])
def add_or_get_prefix(self, token_ids: Sequence[int],
lora_int_id: int) -> Optional[Prefix]:
token_ids = self._truncate_token_ids(token_ids)
if len(token_ids) == 0:
# Prefix is empty.
return None
prefix = Prefix(token_ids, self.block_size)
prefix_hash = hash((prefix, lora_int_id))
if prefix_hash not in self.prefixes:
self.prefixes[prefix_hash] = prefix
return self.prefixes[prefix_hash]
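To make the pool's deduplication behavior concrete, here is a minimal usage sketch. It assumes only the `Prefix`/`PrefixPool` classes shown above (imported from `vllm.prefix`); the block size and token ids are toy values invented for the example.

```python
# Sketch: reusing a cached Prefix across two requests that share a prompt head.
from vllm.prefix import PrefixPool

pool = PrefixPool(block_size=4)
shared_prompt = [101, 7, 7, 9, 4, 4, 4, 4, 12, 13]  # only the first 8 tokens are block-aligned

first = pool.add_or_get_prefix(shared_prompt, lora_int_id=0)
second = pool.add_or_get_prefix(shared_prompt, lora_int_id=0)

assert first is second            # same Prefix object for the same (prefix, LoRA id) key
assert first.get_length() == 8    # tokens that do not fill a whole block are truncated
assert not first.computed         # set to True elsewhere once the prefix KV cache exists
```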
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy
from enum import IntEnum from enum import IntEnum
from functools import cached_property from functools import cached_property
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
...@@ -11,7 +12,8 @@ _SAMPLING_EPS = 1e-5 ...@@ -11,7 +12,8 @@ _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum): class SamplingType(IntEnum):
GREEDY = 0 GREEDY = 0
RANDOM = 1 RANDOM = 1
BEAM = 2 RANDOM_SEED = 2
BEAM = 3
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
...@@ -56,6 +58,7 @@ class SamplingParams: ...@@ -56,6 +58,7 @@ class SamplingParams:
min_p: Float that represents the minimum probability for a token to be min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token. considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this. Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling. use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length. length_penalty: Float that penalizes sequences based on their length.
Used in beam search. Used in beam search.
...@@ -101,6 +104,7 @@ class SamplingParams: ...@@ -101,6 +104,7 @@ class SamplingParams:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
min_p: float = 0.0, min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False, use_beam_search: bool = False,
length_penalty: float = 1.0, length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False, early_stopping: Union[bool, str] = False,
...@@ -108,7 +112,7 @@ class SamplingParams: ...@@ -108,7 +112,7 @@ class SamplingParams:
stop_token_ids: Optional[List[int]] = None, stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False, include_stop_str_in_output: bool = False,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
...@@ -124,6 +128,7 @@ class SamplingParams: ...@@ -124,6 +128,7 @@ class SamplingParams:
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
self.min_p = min_p self.min_p = min_p
self.seed = seed
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.length_penalty = length_penalty self.length_penalty = length_penalty
self.early_stopping = early_stopping self.early_stopping = early_stopping
...@@ -183,7 +188,7 @@ class SamplingParams: ...@@ -183,7 +188,7 @@ class SamplingParams:
if not 0.0 <= self.min_p <= 1.0: if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got " raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.") f"{self.min_p}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError( raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.") f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.logprobs is not None and self.logprobs < 0: if self.logprobs is not None and self.logprobs < 0:
...@@ -229,8 +234,24 @@ class SamplingParams: ...@@ -229,8 +234,24 @@ class SamplingParams:
return SamplingType.BEAM return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS: if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY return SamplingType.GREEDY
if self.seed is not None:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM return SamplingType.RANDOM
def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an
arbitrary, nontrivial amount of data.
See https://github.com/vllm-project/vllm/issues/3087
"""
logit_processor_refs = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
return copy.deepcopy(self, memo=logit_processor_refs)
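The two additions above are easy to demonstrate in isolation: a `seed` switches the sampling type to `RANDOM_SEED`, and `clone()` deep-copies everything except the logits processors, which stay shared by reference via the deepcopy memo. The sketch below assumes the `sampling_type` property name and a `logits_processors` constructor argument, as suggested by the attributes used in the code above.

```python
# Sketch: seed selects RANDOM_SEED sampling; clone() shares logits processors.
from vllm.sampling_params import SamplingParams, SamplingType

class BanTokenZero:
    """Toy stateful logits processor; a plain deepcopy would duplicate it."""
    def __call__(self, token_ids, logits):
        logits[0] = -float("inf")
        return logits

processor = BanTokenZero()
params = SamplingParams(temperature=0.8, seed=1234,
                        logits_processors=[processor])
assert params.sampling_type == SamplingType.RANDOM_SEED

copied = params.clone()
assert copied is not params
# clone() puts the processors into deepcopy's memo, so the (possibly large)
# processor object is shared rather than copied:
assert copied.logits_processors[0] is processor
```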
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"SamplingParams(n={self.n}, " f"SamplingParams(n={self.n}, "
...@@ -242,6 +263,7 @@ class SamplingParams: ...@@ -242,6 +263,7 @@ class SamplingParams:
f"top_p={self.top_p}, " f"top_p={self.top_p}, "
f"top_k={self.top_k}, " f"top_k={self.top_k}, "
f"min_p={self.min_p}, " f"min_p={self.min_p}, "
f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, " f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, " f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, " f"early_stopping={self.early_stopping}, "
......
"""Sequence and its related classes.""" """Sequence and its related classes."""
import copy import copy
import enum import enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
PromptLogprobs = List[Optional[Dict[int, float]]] PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]] SampleLogprobs = List[Dict[int, float]]
...@@ -47,10 +50,28 @@ class SequenceStatus(enum.Enum): ...@@ -47,10 +50,28 @@ class SequenceStatus(enum.Enum):
return finish_reason return finish_reason
@dataclass
class RequestMetrics:
"""Metrics associated with a request.
Args:
arrival_time: The time when the request arrived.
last_token_time: The time when the most recent token was generated.
first_scheduled_time: The time when the request was first scheduled.
first_token_time: The time when the first token was generated.
time_in_queue: The time the request spent in the queue.
finished_time: The time when the request was finished.
"""
arrival_time: float
last_token_time: float
first_scheduled_time: Optional[float]
first_token_time: Optional[float]
time_in_queue: Optional[float]
finished_time: Optional[float] = None
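For readers wiring these fields into monitoring, the derived quantities are straightforward. The sketch below shows the usual time-to-first-token and end-to-end latency calculations using the dataclass defined above; the timestamps are invented for the example (in practice they come from `time.time()` calls in the engine).

```python
# Sketch: deriving common latency numbers from a RequestMetrics instance.
from vllm.sequence import RequestMetrics

metrics = RequestMetrics(arrival_time=100.00,
                         last_token_time=100.90,
                         first_scheduled_time=100.05,
                         first_token_time=100.25,
                         time_in_queue=0.05,
                         finished_time=101.00)

ttft = metrics.first_token_time - metrics.arrival_time      # time to first token
e2e_latency = metrics.finished_time - metrics.arrival_time  # end-to-end latency
print(f"TTFT={ttft:.2f}s, queueing={metrics.time_in_queue:.2f}s, e2e={e2e_latency:.2f}s")
```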
class SequenceData: class SequenceData:
"""Data associated with a sequence. """Data associated with a sequence.
Args: Args:
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
...@@ -105,6 +126,7 @@ class Sequence: ...@@ -105,6 +126,7 @@ class Sequence:
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
lora_request: LoRA request.
""" """
def __init__( def __init__(
...@@ -113,10 +135,12 @@ class Sequence: ...@@ -113,10 +135,12 @@ class Sequence:
prompt: str, prompt: str,
prompt_token_ids: List[int], prompt_token_ids: List[int],
block_size: int, block_size: int,
lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.prompt = prompt self.prompt = prompt
self.block_size = block_size self.block_size = block_size
self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids) self.data = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
...@@ -133,6 +157,10 @@ class Sequence: ...@@ -133,6 +157,10 @@ class Sequence:
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def _append_logical_block(self) -> None: def _append_logical_block(self) -> None:
block = LogicalTokenBlock( block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks), block_number=len(self.logical_token_blocks),
...@@ -188,7 +216,7 @@ class Sequence: ...@@ -188,7 +216,7 @@ class Sequence:
return self.data.cumulative_logprob return self.data.cumulative_logprob
def get_beam_search_score(self, def get_beam_search_score(self,
length_penalty: float = 1.0,
seq_len: Optional[int] = None, seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float: eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty. """Calculate the beam search score with length penalty.
...@@ -220,6 +248,14 @@ class Sequence: ...@@ -220,6 +248,14 @@ class Sequence:
f"num_blocks={len(self.logical_token_blocks)})") f"num_blocks={len(self.logical_token_blocks)})")
@dataclass
class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator: Optional = None
class SequenceGroup: class SequenceGroup:
"""A group of sequences that are generated from the same prompt. """A group of sequences that are generated from the same prompt.
...@@ -228,6 +264,8 @@ class SequenceGroup: ...@@ -228,6 +264,8 @@ class SequenceGroup:
seqs: The list of sequences. seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request. arrival_time: The arrival time of the request.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
""" """
def __init__( def __init__(
...@@ -236,12 +274,21 @@ class SequenceGroup: ...@@ -236,12 +274,21 @@ class SequenceGroup:
seqs: List[Sequence], seqs: List[Sequence],
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
@property @property
def prompt(self) -> str: def prompt(self) -> str:
...@@ -255,6 +302,31 @@ class SequenceGroup: ...@@ -255,6 +302,31 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids return next(iter(self.seqs_dict.values())).data.prompt_token_ids
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def get_last_latency(self, now: float) -> float:
"""Gets last token latency for Request level timings."""
latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now
return latency
def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings."""
if self.metrics.first_token_time is None:
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
"""Sets the first scheduled time and time in queue for Request level timings."""
if self.metrics.first_scheduled_time is None:
self.metrics.first_scheduled_time = time
self.metrics.time_in_queue = time - self.metrics.arrival_time
def set_finished_time(self, time: Optional[float]) -> None:
"""Sets the finished time for Request level timings."""
self.metrics.finished_time = time
def get_max_num_running_seqs(self) -> int: def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining """The maximum number of sequences running in parallel in the remaining
lifetime of the request.""" lifetime of the request."""
...@@ -327,7 +399,6 @@ class SequenceGroup: ...@@ -327,7 +399,6 @@ class SequenceGroup:
class SequenceGroupMetadata: class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`. """Metadata for a sequence group. Used to create `InputMetadata`.
Args: Args:
request_id: The ID of the request. request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage. is_prompt: Whether the request is at prompt stage.
...@@ -335,6 +406,9 @@ class SequenceGroupMetadata: ...@@ -335,6 +406,9 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block block_tables: The block tables. (Seq id -> list of physical block
numbers) numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
""" """
def __init__( def __init__(
...@@ -344,12 +418,22 @@ class SequenceGroupMetadata: ...@@ -344,12 +418,22 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData], seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams, sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
state: Optional[SequenceGroupState] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.is_prompt = is_prompt self.is_prompt = is_prompt
self.seq_data = seq_data self.seq_data = seq_data
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.block_tables = block_tables self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
self.state = SequenceGroupState() if state is None else state
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
class SequenceOutput: class SequenceOutput:
......
import ray
from vllm.config import ParallelConfig
from vllm.utils import get_open_port
from vllm.worker.worker import init_distributed_environment
def init_test_distributed_environment(
pipeline_parallel_size: int,
tensor_parallel_size: int,
rank: int,
distributed_init_port: str,
) -> None:
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
parallel_config,
rank,
cupy_port=None,
distributed_init_method=distributed_init_method)
def multi_process_tensor_parallel(
tensor_parallel_size: int,
test_target,
) -> None:
# Using Ray instead of multiprocessing makes it easier
# to debug failures.
ray.init()
distributed_init_port = get_open_port()
refs = []
for rank in range(tensor_parallel_size):
refs.append(
test_target.remote(tensor_parallel_size, rank,
distributed_init_port))
ray.get(refs)
ray.shutdown()
@@ -5,23 +5,33 @@ from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
_CONFIG_REGISTRY = {
"aquila": AquilaConfig,
"baichuan": BaiChuanConfig,
"chatglm": ChatGLMConfig,
"mpt": MPTConfig,
"qwen": QWenConfig,
"RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
"yi": YiConfig,
"starcoder2": Starcoder2Config,
}
def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None) -> PretrainedConfig:
# FIXME(woosuk): This is a temporary fix for StarCoder2.
# Remove this when the model is supported by HuggingFace transformers.
if "bigcode" in model and "starcoder2" in model:
config_class = _CONFIG_REGISTRY["starcoder2"]
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
return config
try:
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision)
except ValueError as e: except ValueError as e:
if (not trust_remote_code and if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)): "requires you to execute the configuration file" in str(e)):
...@@ -35,5 +45,7 @@ def get_config(model: str, ...@@ -35,5 +45,7 @@ def get_config(model: str,
raise e raise e
if config.model_type in _CONFIG_REGISTRY: if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type] config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
return config
from vllm.transformers_utils.configs.aquila import AquilaConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.yi import YiConfig
from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config
__all__ = [
"AquilaConfig",
"BaiChuanConfig",
"ChatGLMConfig",
"MPTConfig",
"QWenConfig",
"RWConfig",
"YiConfig",
"Starcoder2Config",
]
# coding=utf-8
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Aquila model configuration"""
from transformers import PretrainedConfig
class AquilaConfig(PretrainedConfig):
model_type = "aquila"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=100008,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.006,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
class BaiChuanConfig(PretrainedConfig):
model_type = "baichuan"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=64000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "qwen"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
from transformers import PretrainedConfig
class Starcoder2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Starcoder2Model`]. It is used to instantiate a
Starcoder2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b_16k](https://huggingface.co/bigcode/starcoder2-7b_16k) model.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 49152):
Vocabulary size of the Starcoder2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Starcoder2Model`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 12288):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 30):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 24):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with. Starcoder2's sliding window attention
allows sequences of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_epsilon (`float`, *optional*, defaults to 1e-05):
Epsilon value for the layer norm
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
bos_token_id (`int`, *optional*, defaults to 50256):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 50256):
The id of the "end-of-sequence" token.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `None` (no sliding window).
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
residual_dropout (`float`, *optional*, defaults to 0.0):
Residual connection dropout value.
embedding_dropout (`float`, *optional*, defaults to 0.0):
Embedding dropout.
use_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias term on linear layers of the model.
```python
>>> from transformers import Starcoder2Model, Starcoder2Config
>>> # Initializing a Starcoder2 7B style configuration
>>> configuration = Starcoder2Config()
>>> # Initializing a model from the Starcoder2 7B style configuration
>>> model = Starcoder2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "starcoder2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=49152,
hidden_size=3072,
intermediate_size=12288,
num_hidden_layers=30,
num_attention_heads=24,
num_key_value_heads=2,
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=4096,
initializer_range=0.018042,
norm_epsilon=1e-5,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
rope_theta=10000.0,
sliding_window=None,
attention_dropout=0.0,
residual_dropout=0.0,
embedding_dropout=0.0,
use_bias=True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.use_bias = use_bias
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.norm_epsilon = norm_epsilon
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.embedding_dropout = embedding_dropout
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
if self.architectures is None:
self.architectures = ['Starcoder2ForCausalLM']
""" Yi model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
Yi_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class YiConfig(PretrainedConfig):
r"""
Reference:
https://huggingface.co/01-ai/Yi-6B/blob/main/configuration_yi.py
"""
model_type = "Yi"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=64000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
output_attentions=False,
rope_theta=5000000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.output_attentions = output_attentions
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@@ -4,6 +4,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
logger = init_logger(__name__)
@@ -65,6 +67,84 @@ def get_tokenizer(
return tokenizer
def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[PreTrainedTokenizer]:
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
**kwargs)
except OSError as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
f"No tokenizer found in {lora_request.lora_local_path}, "
"using base model tokenizer instead. "
f"(Exception: {str(e)})")
tokenizer = None
return tokenizer
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
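A short usage sketch of the class above: without a LoRA request the group simply delegates to the base tokenizer, and with `enable_lora=True` per-adapter tokenizers are cached in an LRU keyed by `lora_int_id`. The model name and adapter path below are placeholders, and the `LoRARequest` constructor fields (a name, an integer id, and a local path) are assumed from how the object is used above.

```python
# Sketch: encoding with and without a LoRA adapter through TokenizerGroup.
# "facebook/opt-125m" and "/path/to/adapter" are placeholder values.
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup

group = TokenizerGroup(tokenizer_id="facebook/opt-125m",
                       enable_lora=True,
                       max_num_seqs=256,
                       max_input_length=None)

# No LoRA request: the base tokenizer is used.
base_ids = group.encode("Hello world")

# With a LoRA request: the adapter's tokenizer is loaded (or the base tokenizer
# is reused if the adapter ships none) and cached in the LRU.
lora = LoRARequest(lora_name="my-adapter", lora_int_id=1,
                   lora_local_path="/path/to/adapter")
lora_ids = group.encode("Hello world", lora_request=lora)
```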
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
......
import enum
import os
import socket
import subprocess
import uuid
from platform import uname
from typing import List, Tuple, Union
from packaging.version import parse, Version
import psutil
import torch
import asyncio
from functools import partial
from typing import (
Awaitable,
Callable,
TypeVar,
)
from collections import OrderedDict
from typing import Any, Hashable, Optional
from vllm.logger import init_logger
T = TypeVar("T")
logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
# "bfloat16": torch.bfloat16,
"float": torch.float,
# "fp8_e5m2": torch.uint8,
}
class Device(enum.Enum):
@@ -30,16 +51,91 @@ class Counter:
self.counter = 0
class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Any:
return self.get(key)
def __setitem__(self, key: Hashable, value: Any) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
self.pop(key)
def touch(self, key: Hashable) -> None:
self.cache.move_to_end(key)
def get(self, key: Hashable, default_value: Optional[Any] = None) -> Optional[Any]:
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
return value
def put(self, key: Hashable, value: Any) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: Any):
pass
def remove_oldest(self):
if not self.cache:
return
key, value = self.cache.popitem(last=False)
self._on_remove(key, value)
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
run_on_remove = key in self.cache
value = self.cache.pop(key, default_value)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self):
while len(self.cache) > 0:
self.remove_oldest()
self.cache.clear()
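The cache above is a thin OrderedDict wrapper with move-to-end semantics. A minimal sketch of its eviction behavior, subclassing only to observe `_on_remove` (an extension hook; nothing in the tokenizer code requires overriding it):

```python
# Sketch: LRU eviction order and the _on_remove hook of the LRUCache above.
class LoggingLRUCache(LRUCache):
    def _on_remove(self, key, value):
        print(f"evicted {key!r}")  # called on eviction and on pop()

cache = LoggingLRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")      # touches "a", so "b" becomes the least recently used entry
cache.put("c", 3)   # exceeds capacity -> prints: evicted 'b'
assert "b" not in cache and cache.get("a") == 1 and cache["c"] == 3
```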
def is_hip() -> bool:
return torch.version.hip is not None
def is_neuron() -> bool:
try:
import transformers_neuronx
except ImportError:
transformers_neuronx = None
return transformers_neuronx is not None
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute(
gpu)
# A value of 0 would make MAX_SEQ_LEN negative and cause test_attention.py to fail.
assert max_shared_mem > 0, "max_shared_mem cannot be zero"
return int(max_shared_mem)
...@@ -57,17 +153,159 @@ def in_wsl() -> bool: ...@@ -57,17 +153,159 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(uname()).lower()
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
asyncio event loop.
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args, **kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func)
return _async_wrapper
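A minimal sketch of the wrapper above: the blocking call runs in the default thread-pool executor, so the event loop stays responsive. `slow_tokenize` is an invented stand-in for any blocking function.

```python
# Sketch: offloading a blocking call with make_async (slow_tokenize is invented).
import asyncio
import time

def slow_tokenize(text: str) -> int:
    time.sleep(0.1)          # pretend this is CPU/IO-bound tokenization
    return len(text.split())

slow_tokenize_async = make_async(slow_tokenize)

async def main() -> None:
    # Both calls run in executor threads; the event loop is not blocked.
    counts = await asyncio.gather(slow_tokenize_async("hello world"),
                                  slow_tokenize_async("one two three"))
    print(counts)  # [2, 3]

asyncio.run(main())
```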
def get_ip() -> str:
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("dns.google", 80))  # Doesn't need to be reachable
return s.getsockname()[0]
except OSError:
# try ipv6
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
s.connect(("dns.google", 80))
return s.getsockname()[0]
def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://{ip}:{port}"
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def get_nvcc_cuda_version() -> Optional[Version]:
cuda_home = os.environ.get('CUDA_HOME')
if not cuda_home:
cuda_home = '/usr/local/cuda'
if os.path.isfile(cuda_home + '/bin/nvcc'):
logger.info(
f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
)
else:
logger.warning(
f'nvcc not found in {cuda_home}. Skipping CUDA version check.')
return None
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version
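Callers typically use the returned `Version` (or `None`) to gate optional setup steps. A small hedged example follows; the 11.8 threshold is chosen arbitrarily for illustration and is not a vLLM requirement.

```python
# Sketch: gating an optional code path on the detected nvcc version.
nvcc_version = get_nvcc_cuda_version()
if nvcc_version is None:
    print("nvcc not found; skipping CUDA-version-specific setup")
elif nvcc_version >= parse("11.8"):
    print(f"nvcc {nvcc_version}: newer toolkit features available")
else:
    print(f"nvcc {nvcc_version}: using baseline kernels")
```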
def _generate_random_fp8_e5m2(
tensor: torch.tensor,
low: float,
high: float,
) -> None:
# NOTE(zhaoyang): Due to the NaN and Inf representations of the fp8 data type,
# Inf or NaN may be produced if we directly use torch.randint
# to generate random fp8 data.
# For example, s.11111.00 in fp8e5m2 format represents Inf.
# | E4M3 | E5M2
#-----|-------------|-------------------
# Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
from vllm._C import cache_ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high)
cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
del tensor_tmp
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
if isinstance(model_dtype, str):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
elif isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype
else:
raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in ["half", "bfloat16", "float"]:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8_e5m2":
torch_dtype = torch.uint8
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
elif isinstance(cache_dtype, torch.dtype):
torch_dtype = cache_dtype
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(key_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
key_cache.uniform_(-scale, scale)
else:
raise ValueError(
f"Does not support key cache of type {cache_dtype}")
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(value_cache, -scale, scale)
elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
value_cache.uniform_(-scale, scale)
else:
raise ValueError(
f"Does not support value cache of type {cache_dtype}")
value_caches.append(value_cache)
return key_caches, value_caches
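The helper above is mainly used by kernel tests and micro-benchmarks. A hedged call sketch with small toy sizes follows; with the default `device="cuda"` it requires a CUDA-capable GPU.

```python
# Sketch: allocating small random KV caches for a kernel micro-benchmark.
# Sizes are toy values; requires a CUDA device with the default device="cuda".
key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=128,
    block_size=16,
    num_layers=2,
    num_heads=8,
    head_size=64,
    cache_dtype="auto",          # follow the model dtype given below
    model_dtype=torch.float16,
    seed=0,
)
assert len(key_caches) == 2 and key_caches[0].shape[0] == 128
```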
...@@ -3,10 +3,9 @@ from typing import Dict, List, Tuple ...@@ -3,10 +3,9 @@ from typing import Dict, List, Tuple
import torch import torch
from vllm._C import cache_ops
from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -34,12 +33,20 @@ class CacheEngine: ...@@ -34,12 +33,20 @@ class CacheEngine:
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config) self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks
# Skip initializing CUDA stream and buffer for Neuron backend.
if is_neuron():
return
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Initialize the cache. # Initialize the cache.
self.gpu_cache = self.allocate_gpu_cache() self.gpu_cache = self.allocate_gpu_cache()
self.cpu_cache = self.allocate_cpu_cache() self.cpu_cache = self.allocate_cpu_cache()
...@@ -100,11 +107,13 @@ class CacheEngine: ...@@ -100,11 +107,13 @@ class CacheEngine:
size=(self.num_cpu_blocks, *key_block_shape), size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype, dtype=self.dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
device="cpu",
) )
value_blocks = torch.empty( value_blocks = torch.empty(
size=(self.num_cpu_blocks, *value_block_shape), size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype, dtype=self.dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
device="cpu",
) )
cpu_cache.append((key_blocks, value_blocks)) cpu_cache.append((key_blocks, value_blocks))
return cpu_cache return cpu_cache
...@@ -115,6 +124,8 @@ class CacheEngine: ...@@ -115,6 +124,8 @@ class CacheEngine:
dst: List[KVCache], dst: List[KVCache],
src_to_dst: Dict[int, int], src_to_dst: Dict[int, int],
) -> None: ) -> None:
from vllm._C import cache_ops
with torch.cuda.stream(self.cache_stream): with torch.cuda.stream(self.cache_stream):
for i in range(self.num_layers): for i in range(self.num_layers):
src_key_cache, src_value_cache = src[i] src_key_cache, src_value_cache = src[i]
...@@ -134,6 +145,8 @@ class CacheEngine: ...@@ -134,6 +145,8 @@ class CacheEngine:
self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
from vllm._C import cache_ops
key_caches = [key_cache for key_cache, _ in self.gpu_cache] key_caches = [key_cache for key_cache, _ in self.gpu_cache]
value_caches = [value_cache for _, value_cache in self.gpu_cache] value_caches = [value_cache for _, value_cache in self.gpu_cache]
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU. # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
...@@ -142,6 +155,7 @@ class CacheEngine: ...@@ -142,6 +155,7 @@ class CacheEngine:
@staticmethod @staticmethod
def get_cache_block_size( def get_cache_block_size(
block_size: int, block_size: int,
cache_dtype: str,
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
) -> int: ) -> int:
...@@ -152,7 +166,11 @@ class CacheEngine: ...@@ -152,7 +166,11 @@ class CacheEngine:
key_cache_block = block_size * num_heads * head_size key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block) total = num_layers * (key_cache_block + value_cache_block)
if cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype_size = _get_dtype_size(dtype)
return dtype_size * total return dtype_size * total
......
import contextlib
import time
from typing import Dict, List, Optional, Tuple, Set, Union
import numpy as np
import torch
import torch.nn as nn
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce)
from vllm.model_executor.parallel_utils import custom_all_reduce
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
...@@ -30,19 +40,28 @@ class ModelRunner: ...@@ -30,19 +40,28 @@ class ModelRunner:
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py. # model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this. # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window() self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None) if model_config is not None else None)
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None self.model = None
self.block_size = None # Set after initial profiling. self.block_size = None # Set after initial profiling.
self.lora_manager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture. self.graph_memory_pool = None # Set during graph capture.
...@@ -59,9 +78,37 @@ class ModelRunner: ...@@ -59,9 +78,37 @@ class ModelRunner:
self.graph_block_tables = None # Set after initial profiling. self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result # cache in_wsl result
self.in_wsl = in_wsl() self.in_wsl = in_wsl()
self.kv_cache_dtype = kv_cache_dtype
# Set enforce_eager to True for the Neuron backend, to avoid graph capture.
if self.device_config.is_neuron:
self.model_config.enforce_eager = True
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(self.model_config) self.model = get_model(self.model_config,
self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
vocab_size = self.model.config.vocab_size
if self.lora_config:
assert hasattr(
self.model, "supported_lora_modules"
) and self.model.supported_lora_modules, "Model does not support LoRA"
assert hasattr(
self.model,
"embedding_modules"), "Model does not have embedding_modules"
assert hasattr(self.model, "embedding_padding_modules"
), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
def set_block_size(self, block_size: int) -> None: def set_block_size(self, block_size: int) -> None:
self.block_size = block_size self.block_size = block_size
...@@ -74,13 +121,20 @@ class ModelRunner: ...@@ -74,13 +121,20 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = [] prompt_lens: List[int] = []
context_lens: List[int] = []
subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -91,11 +145,34 @@ class ModelRunner: ...@@ -91,11 +145,34 @@ class ModelRunner:
prompt_tokens = seq_data.get_token_ids() prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens) prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len) prompt_lens.append(prompt_len)
prefix_len = 0
prefix = seq_group_metadata.prefix
if prefix is not None and prefix.computed:
prefix_len = prefix.get_length()
prompt_tokens = prompt_tokens[prefix_len:]
prefix_block_tables.append(prefix.get_block_numbers())
else:
prefix_block_tables.append([])
# actual prompt lens
context_lens.append(prefix_len)
subquery_lens.append(prompt_len - prefix_len)
input_tokens.append(prompt_tokens) input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.append(list(range(prompt_len))) input_positions.append(
list(range(prefix_len, prefix_len + len(prompt_tokens))))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - prefix_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
...@@ -113,8 +190,11 @@ class ModelRunner: ...@@ -113,8 +190,11 @@ class ModelRunner:
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0 start_idx = 0
if self.sliding_window is not None: if self.sliding_window is not None:
assert prefix_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window) start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prompt_len): for i in range(prefix_len, prompt_len):
if i < start_idx: if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID) slot_mapping[-1].append(_PAD_SLOT_ID)
continue continue
...@@ -124,45 +204,87 @@ class ModelRunner: ...@@ -124,45 +204,87 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot) slot_mapping[-1].append(slot)
max_prompt_len = max(prompt_lens) max_prompt_len = max(subquery_lens)
input_tokens = _make_tensor_with_pad(input_tokens, input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len, max_prompt_len,
pad=0, pad=0,
dtype=torch.long) dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions, input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len, max_prompt_len,
pad=0, pad=0,
dtype=torch.long) dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping, slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len, max_prompt_len,
pad=_PAD_SLOT_ID, pad=_PAD_SLOT_ID,
dtype=torch.long) dtype=torch.long,
device=self.device)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device=self.device)
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=prompt_lens_tensor,
max_seq_len=max_prompt_len,
start_loc=start_loc_tensor,
max_context_len=None, max_context_len=None,
context_lens=None, context_lens=context_lens_tensor,
block_tables=None, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
) )
return input_tokens, input_positions, input_metadata, prompt_lens return (input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests)
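As a concrete illustration of the prefix-caching bookkeeping above, here is a small sketch for a single sequence (hypothetical numbers):
# Hypothetical example: a 10-token prompt whose first 4 tokens form a computed prefix.
prompt_len = 10
prefix_len = 4                                         # prefix.get_length()
prompt_tokens = list(range(100, 110))[prefix_len:]     # only the uncached suffix is fed to the model
context_len = prefix_len                               # 4: tokens already present in the KV cache
subquery_len = prompt_len - prefix_len                 # 6: tokens computed in this forward pass
input_positions = list(range(prefix_len, prefix_len + len(prompt_tokens)))  # [4, 5, ..., 9]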
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
context_lens: List[int] = [] context_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id() generation_token = seq_data.get_last_token_id()
...@@ -181,6 +303,8 @@ class ModelRunner: ...@@ -181,6 +303,8 @@ class ModelRunner:
block_offset = position % self.block_size block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append([slot]) slot_mapping.append([slot])
lora_index_mapping.append([lora_id])
lora_prompt_mapping.append(lora_id)
if self.sliding_window is not None: if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window // sliding_window_blocks = (self.sliding_window //
...@@ -211,20 +335,20 @@ class ModelRunner: ...@@ -211,20 +335,20 @@ class ModelRunner:
max_len=1, max_len=1,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device="cuda") device=self.device)
input_positions = _make_tensor_with_pad(input_positions, input_positions = _make_tensor_with_pad(input_positions,
max_len=1, max_len=1,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device="cuda") device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping, slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1, max_len=1,
pad=_PAD_SLOT_ID, pad=_PAD_SLOT_ID,
dtype=torch.long, dtype=torch.long,
device="cuda") device=self.device)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device="cuda") device=self.device)
if use_captured_graph: if use_captured_graph:
# The shape of graph_block_tables is # The shape of graph_block_tables is
...@@ -233,40 +357,52 @@ class ModelRunner: ...@@ -233,40 +357,52 @@ class ModelRunner:
for i, block_table in enumerate(block_tables): for i, block_table in enumerate(block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device="cuda") block_tables = torch.tensor(input_block_tables, device=self.device)
else: else:
max_block_table_len = (max_context_len + self.block_size - max_block_table_len = max(
1) // self.block_size len(block_table) for block_table in block_tables)
block_tables = _make_tensor_with_pad( block_tables = _make_tensor_with_pad(
block_tables, block_tables,
max_len=max_block_table_len, max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device="cuda", device=self.device,
) )
lora_index_mapping = [
_pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
]
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=None,
max_seq_len=None,
start_loc=None,
max_context_len=max_context_len, max_context_len=max_context_len,
context_lens=context_lens, context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
) )
return input_tokens, input_positions, input_metadata return (input_tokens, input_positions, input_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
def _prepare_sample( def _prepare_sample(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int], prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata: ) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = [] seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = [] selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0 selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0 categorized_sample_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron
max_prompt_len = max(prompt_lens) if prompt_lens else 1 max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list): for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params sampling_params = seq_group_metadata.sampling_params
...@@ -274,10 +410,11 @@ class ModelRunner: ...@@ -274,10 +410,11 @@ class ModelRunner:
if seq_group_metadata.is_prompt: if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1 assert len(seq_ids) == 1
prompt_len = prompt_lens[i] assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None: if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip # NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1 categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[ categorized_sample_indices[
sampling_params.sampling_type].append( sampling_params.sampling_type].append(
...@@ -287,10 +424,14 @@ class ModelRunner: ...@@ -287,10 +424,14 @@ class ModelRunner:
if sampling_params.prompt_logprobs is not None: if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend( selected_token_indices.extend(
range(selected_token_start_idx, range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1)) selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx + selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) subquery_len - 1)
selected_token_start_idx += max_prompt_len selected_token_start_idx += max_subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device="cuda").manual_seed(sampling_params.seed)
else: else:
num_seqs = len(seq_ids) num_seqs = len(seq_ids)
selected_token_indices.extend( selected_token_indices.extend(
...@@ -304,11 +445,18 @@ class ModelRunner: ...@@ -304,11 +445,18 @@ class ModelRunner:
categorized_sample_indices_start_idx + num_seqs)) categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs categorized_sample_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = _async_h2d(selected_token_indices, selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long, dtype=torch.long,
pin_memory=not self.in_wsl) target_device=self.device,
pin_memory=pin_memory)
categorized_sample_indices = { categorized_sample_indices = {
t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=pin_memory)
for t, seq_ids in categorized_sample_indices.items() for t, seq_ids in categorized_sample_indices.items()
} }
...@@ -322,121 +470,95 @@ class ModelRunner: ...@@ -322,121 +470,95 @@ class ModelRunner:
prompt_lens=prompt_lens, prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices, selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices, categorized_sample_indices=categorized_sample_indices,
generators=generators,
) )
return sampling_metadata return sampling_metadata
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
Set[int], LoRAMapping]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, input_metadata, (input_tokens, input_positions, input_metadata, prompt_lens,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list) subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, input_metadata (input_tokens, input_positions, input_metadata,
) = self._prepare_decode(seq_group_metadata_list) lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] prompt_lens = []
subquery_lens = None
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens) prompt_lens,
subquery_lens)
def get_size_or_none(x: Optional[torch.Tensor]):
return x.size() if x is not None else None if self.lora_config:
flat_lora_index_mapping = [
# Broadcast the input data. For input tensors, we first broadcast item for sublist in lora_index_mapping for item in sublist
# its shape and then broadcast the tensor to avoid high ]
# serialization cost. lora_mapping = LoRAMapping(
py_data = { flat_lora_index_mapping,
"input_tokens_size": lora_prompt_mapping,
input_tokens.size(), )
"input_positions_size": else:
input_positions.size(), lora_mapping = None
"is_prompt":
input_metadata.is_prompt, # Broadcast the metadata.
"slot_mapping_size": metadata_dict = {
get_size_or_none(input_metadata.slot_mapping), "input_tokens": input_tokens,
"max_context_len": "input_positions": input_positions,
input_metadata.max_context_len, "is_prompt": input_metadata.is_prompt,
"context_lens_size": "slot_mapping": input_metadata.slot_mapping,
get_size_or_none(input_metadata.context_lens), "prompt_lens": input_metadata.prompt_lens,
"block_tables_size": "max_seq_len": input_metadata.max_seq_len,
get_size_or_none(input_metadata.block_tables), "start_loc": input_metadata.start_loc,
"use_cuda_graph": "max_context_len": input_metadata.max_context_len,
input_metadata.use_cuda_graph, "context_lens": input_metadata.context_lens,
"selected_token_indices_size": "block_tables": input_metadata.block_tables,
sampling_metadata.selected_token_indices.size(), "use_cuda_graph": input_metadata.use_cuda_graph,
"kv_cache_dtype": input_metadata.kv_cache_dtype,
"selected_token_indices":
sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
} }
broadcast_object_list([py_data], src=0) broadcast_tensor_dict(metadata_dict, src=0)
# TODO(zhuohan): Combine the broadcasts or set async_op=True.
broadcast(input_tokens, src=0)
broadcast(input_positions, src=0)
if input_metadata.slot_mapping is not None:
broadcast(input_metadata.slot_mapping, src=0)
if input_metadata.context_lens is not None:
broadcast(input_metadata.context_lens, src=0)
if input_metadata.block_tables is not None:
broadcast(input_metadata.block_tables, src=0)
broadcast(sampling_metadata.selected_token_indices, src=0)
else: else:
receving_list = [None] metadata_dict = broadcast_tensor_dict(src=0)
broadcast_object_list(receving_list, src=0) input_tokens = metadata_dict["input_tokens"]
py_data = receving_list[0] input_positions = metadata_dict["input_positions"]
input_tokens = torch.empty(*py_data["input_tokens_size"], lora_mapping = metadata_dict["lora_mapping"]
dtype=torch.long, lora_requests = metadata_dict["lora_requests"]
device="cuda")
broadcast(input_tokens, src=0)
input_positions = torch.empty(*py_data["input_positions_size"],
dtype=torch.long,
device="cuda")
broadcast(input_positions, src=0)
if py_data["slot_mapping_size"] is not None:
slot_mapping = torch.empty(*py_data["slot_mapping_size"],
dtype=torch.long,
device="cuda")
broadcast(slot_mapping, src=0)
else:
slot_mapping = None
if py_data["context_lens_size"] is not None:
context_lens = torch.empty(*py_data["context_lens_size"],
dtype=torch.int,
device="cuda")
broadcast(context_lens, src=0)
else:
context_lens = None
if py_data["block_tables_size"] is not None:
block_tables = torch.empty(*py_data["block_tables_size"],
dtype=torch.int,
device="cuda")
broadcast(block_tables, src=0)
else:
block_tables = None
selected_token_indices = torch.empty(
*py_data["selected_token_indices_size"],
dtype=torch.long,
device="cuda")
broadcast(selected_token_indices, src=0)
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=py_data["is_prompt"], is_prompt=metadata_dict["is_prompt"],
slot_mapping=slot_mapping, slot_mapping=metadata_dict["slot_mapping"],
max_context_len=py_data["max_context_len"], prompt_lens=metadata_dict["prompt_lens"],
context_lens=context_lens, max_seq_len=metadata_dict["max_seq_len"],
block_tables=block_tables, start_loc=metadata_dict["start_loc"],
use_cuda_graph=py_data["use_cuda_graph"], max_context_len=metadata_dict["max_context_len"],
context_lens=metadata_dict["context_lens"],
block_tables=metadata_dict["block_tables"],
use_cuda_graph=metadata_dict["use_cuda_graph"],
kv_cache_dtype=metadata_dict["kv_cache_dtype"],
) )
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
prompt_lens=None, prompt_lens=None,
selected_token_indices=selected_token_indices, selected_token_indices=metadata_dict["selected_token_indices"],
categorized_sample_indices=None, categorized_sample_indices=None,
generators=None,
perform_sampling=False, perform_sampling=False,
) )
return input_tokens, input_positions, input_metadata, sampling_metadata return (input_tokens, input_positions, input_metadata,
sampling_metadata, lora_requests, lora_mapping)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
...@@ -444,8 +566,13 @@ class ModelRunner: ...@@ -444,8 +566,13 @@ class ModelRunner:
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
input_tokens, input_positions, input_metadata, sampling_metadata = ( (input_tokens, input_positions, input_metadata, sampling_metadata,
self.prepare_input_tensors(seq_group_metadata_list)) lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model. # Execute the model.
if input_metadata.use_cuda_graph: if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
...@@ -474,6 +601,28 @@ class ModelRunner: ...@@ -474,6 +601,28 @@ class ModelRunner:
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique LoRAs, and therefore the maximum amount of memory
# consumption. Create dummy LoRA request copies from the LoRA request
# passed in, which contains a LoRA from the LoRA warmup path.
dummy_lora_requests = []
dummy_lora_requests_per_seq = []
if self.lora_config:
for idx in range(self.lora_config.max_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
...@@ -487,6 +636,8 @@ class ModelRunner: ...@@ -487,6 +636,8 @@ class ModelRunner:
seq_data={group_id: seq_data}, seq_data={group_id: seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
) )
seqs.append(seq) seqs.append(seq)
...@@ -497,8 +648,38 @@ class ModelRunner: ...@@ -497,8 +648,38 @@ class ModelRunner:
torch.cuda.synchronize() torch.cuda.synchronize()
return return
def remove_all_loras(self) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_loras()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None: def capture_model(self, kv_caches: List[KVCache]) -> None:
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
assert not self.model_config.enforce_eager assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to " logger.info("Capturing the model for CUDA graphs. This may lead to "
"unexpected consequences if the model is not static. To " "unexpected consequences if the model is not static. To "
...@@ -527,35 +708,61 @@ class ModelRunner: ...@@ -527,35 +708,61 @@ class ModelRunner:
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
] ]
# NOTE: Capturing the largest batch size first may help reduce the # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
# memory usage of CUDA graph. # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
for batch_size in reversed(batch_size_capture_list): # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
# Create dummy input_metadata. # graph, we use either custom all-reduce kernel or PyTorch NCCL.
input_metadata = InputMetadata( # We always prioritize using custom all-reduce kernel but fall back
is_prompt=False, # to PyTorch or CuPy NCCL if it is disabled or not supported.
slot_mapping=slot_mapping[:batch_size], with custom_all_reduce.capture():
max_context_len=self.max_context_len_to_capture, # NOTE: Capturing the largest batch size first may help reduce the
context_lens=context_lens[:batch_size], # memory usage of CUDA graph.
block_tables=block_tables[:batch_size], for batch_size in reversed(batch_size_capture_list):
use_cuda_graph=True, # Create dummy input_metadata.
) input_metadata = InputMetadata(
is_prompt=False,
graph_runner = CUDAGraphRunner(self.model) slot_mapping=slot_mapping[:batch_size],
graph_runner.capture( prompt_lens=None,
input_tokens[:batch_size], max_seq_len=None,
input_positions[:batch_size], start_loc=None,
kv_caches, max_context_len=self.max_context_len_to_capture,
input_metadata, context_lens=context_lens[:batch_size],
memory_pool=self.graph_memory_pool, block_tables=block_tables[:batch_size],
) use_cuda_graph=True,
self.graph_memory_pool = graph_runner.graph.pool() kv_cache_dtype=self.kv_cache_dtype,
self.graph_runners[batch_size] = graph_runner )
if self.lora_config:
lora_mapping = LoRAMapping(
[0] * batch_size,
[0] * batch_size,
)
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
kv_caches,
input_metadata,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
end_time = time.perf_counter() end_time = time.perf_counter()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
# This usually takes < 10 seconds. # This usually takes < 10 seconds.
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
def __del__(self) -> None:
# Delete the CUDA graphs before deleting the CuPy NCCL communicator.
# NOTE(woosuk): This is necessary because otherwise deadlocks can
# happen.
# FIXME(woosuk): This is a bit hacky. Find a more robust solution.
self.graph_runners.clear()
self.cupy_nccl_backend = None
class CUDAGraphRunner: class CUDAGraphRunner:
...@@ -577,18 +784,8 @@ class CUDAGraphRunner: ...@@ -577,18 +784,8 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph. # Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune). # kernel launches for initial benchmarking (e.g., Triton autotune).
self.model( with _maybe_cupy_nccl():
input_ids, self.model(
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()
# Capture the graph.
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool):
hidden_states = self.model(
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
...@@ -596,6 +793,20 @@ class CUDAGraphRunner: ...@@ -596,6 +793,20 @@ class CUDAGraphRunner:
) )
torch.cuda.synchronize() torch.cuda.synchronize()
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with _maybe_cupy_nccl():
hidden_states = self.model(
input_ids,
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()
# Save the input and output buffers. # Save the input and output buffers.
self.input_buffers = { self.input_buffers = {
"input_ids": input_ids, "input_ids": input_ids,
...@@ -638,6 +849,15 @@ class CUDAGraphRunner: ...@@ -638,6 +849,15 @@ class CUDAGraphRunner:
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
@contextlib.contextmanager
def _maybe_cupy_nccl():
if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
with with_cupy_nccl_for_all_reduce():
yield
else:
yield
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len assert len(x) <= max_len
return x + [pad] * (max_len - len(x)) return x + [pad] * (max_len - len(x))
...@@ -648,14 +868,10 @@ def _make_tensor_with_pad( ...@@ -648,14 +868,10 @@ def _make_tensor_with_pad(
max_len: int, max_len: int,
pad: int, pad: int,
dtype: torch.dtype, dtype: torch.dtype,
device: Union[str, torch.device] = "cuda", device: Optional[Union[str, torch.device]],
pin_memory: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, return torch.tensor(padded_x, dtype=dtype, device=device)
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")
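A minimal usage sketch of the padding helpers above (values are illustrative):
import torch
# _pad_to_max([1, 2], max_len=3, pad=0) -> [1, 2, 0]
# _make_tensor_with_pad right-pads each row to max_len and stacks the rows:
rows = [[1, 2], [3]]
padded = _make_tensor_with_pad(rows, max_len=3, pad=0, dtype=torch.long, device="cpu")
# tensor([[1, 2, 0],
#         [3, 0, 0]])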
def _get_graph_batch_size(batch_size: int) -> int: def _get_graph_batch_size(batch_size: int) -> int:
...@@ -667,6 +883,11 @@ def _get_graph_batch_size(batch_size: int) -> int: ...@@ -667,6 +883,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
return (batch_size + 7) // 8 * 8 return (batch_size + 7) // 8 * 8
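For batch sizes in the captured range, the visible return line rounds the requested batch size up to the next multiple of 8; the earlier branches of this function are elided by the hunk, so this sketch only illustrates the line shown:
# Illustrates only the rounding on the visible return line.
for bs in (8, 9, 13, 16, 250):
    print(bs, (bs + 7) // 8 * 8)   # 8->8, 9->16, 13->16, 16->16, 250->256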
def _async_h2d(data: list, dtype, pin_memory): def _async_h2d(
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) data: list,
return t.to(device="cuda", non_blocking=True) dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
"""A Neuron worker class."""
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
class Worker:
"""A worker class that executes the model on a group of neuron cores.
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
def init_model(self) -> None:
# Initialize the distributed environment.
_init_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
distributed_backend="gloo")
# Initialize the model.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
self,
block_size: int = 128,
gpu_memory_utilization: float = 0.9,
cpu_swap_space: int = 0,
cache_dtype: str = "float16",
) -> Tuple[int, int]:
"""Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks."""
num_gpu_blocks = self.scheduler_config.max_num_seqs
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None:
# Warm up is maintained in transformers-neuronx
pass
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
issued_cache_op = True
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
issued_cache_op = True
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
issued_cache_op = True
cache_events = self.cache_events if issued_cache_op else None
# Wait for cache operations to finish.
if cache_events is not None:
raise NotImplementedError(
"cache operations are not implemented for neuron backend.")
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
return output
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
distributed_backend: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
distributed_backend = distributed_backend if distributed_backend else "nccl"
torch.distributed.init_process_group(
backend=distributed_backend,
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1))
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
from typing import List, Dict
import copy
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker
class MultiStepWorker(Worker):
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
by invoking the scheduler less often.
The MultiStepWorker does not support cache swap operations or beam search.
Cache swap operations would not require large modifications to support. Beam
search, on the other hand, requires memory allocations during sequence forks
and thus needs more thought before the MultiStepWorker can support it.
"""
@torch.inference_mode()
def execute_model_multi_step(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_steps: int,
) -> List[SamplerOutput]:
"""Run the model forward pass num_steps times. Returns the list of
sampler outputs, one per model forward pass.
"""
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
blocks_to_swap_out, blocks_to_copy)
# Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list)
# Assert enough KV space for num_steps tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
# Run model num_steps times.
model_outputs = []
for _ in range(num_steps):
model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs
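A minimal driver-side usage sketch of the multi-step path, assuming a MultiStepWorker initialized the same way as a regular Worker; the variable names other than the method itself are placeholders:
# Hypothetical call: run 4 decode steps per scheduler invocation.
outputs = multi_step_worker.execute_model_multi_step(
    seq_group_metadata_list=scheduled_seq_group_metadata,  # decode-only groups, one seq each
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},     # cache ops must be empty; see _raise_if_unsupported
    num_steps=4,
)
assert len(outputs) == 4   # one SamplerOutput per forward pass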
def _append_new_tokens(
self, model_output: SamplerOutput,
seq_group_metadata_list: SequenceGroupMetadata) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, model_output):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
# NOTE: Beam search is not supported, so we can assume that
# parent_seq_id == seq_id.
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
) -> List[SequenceGroupMetadata]:
"""Copy input data structures to remove side-effects when input data
structures are shared with other modules.
The multi-step worker must be able to append tokens to sequences after
a forward pass. This necessitates modification of the data structures
used by the worker. Since these data structures are shared with other
parts of vLLM, like the scheduler, we must take care not to introduce
unexpected side-effects.
When Ray is used to orchestrate worker processes (such as when the
tensor-parallel degree is >1), this is not a problem because the input
data structures will be serialized and created anew in the worker
process.
However, when Ray is not used to orchestrate the worker processes (such
as when the tensor-parallel degree is 1), this is a problem. We avoid
the problem by shallow-copying the input data structures (specifically,
the parts that will change in multiple steps).
"""
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
# append tokens and change is_prompt without external side-effects.
new_seq_group_metadata_list = []
for old_seq_group_metadata in seq_group_metadata_list:
# We must shallow-copy seq_group_metadata as is_prompt could change.
seq_group_metadata = copy.copy(old_seq_group_metadata)
new_seq_group_metadata_list.append(seq_group_metadata)
# We must shallow-copy seq_data as we will append token ids
new_seq_data = {}
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
new_seq_data[seq_id] = copy.copy(old_seq_data)
new_seq_data[
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
seq_group_metadata.seq_data = new_seq_data
return new_seq_group_metadata_list
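The effect of this copying strategy can be seen with a toy stand-in for SequenceData (illustrative only; not the real class):
import copy
class ToySeqData:                      # stand-in for SequenceData
    def __init__(self, output_token_ids):
        self.output_token_ids = output_token_ids
original = ToySeqData(output_token_ids=[7, 8])
cloned = copy.copy(original)                             # shallow copy of the object...
cloned.output_token_ids = original.output_token_ids[:]   # ...plus a fresh token-id list
cloned.output_token_ids.append(9)
print(original.output_token_ids)   # [7, 8] -- the scheduler's copy is untouched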
def _assert_enough_kv_space(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
num_steps: int) -> None:
"""Assert there are enough physical blocks per sequence to store the
current KV plus additional KV from num_steps tokens.
"""
assert self.model_runner.block_size is not None
for seq_group_metadata in seq_group_metadata_list:
# Only one seq_id is guaranteed because there is no beam search.
seq_id = list(seq_group_metadata.seq_data.keys())[0]
seq = seq_group_metadata.seq_data[seq_id]
# After num_steps, the seq len will be the current seq len
# plus one token per step.
final_seq_len = seq.get_len() + num_steps
# We will have final_seq_len - 1 KV because vLLM saves KV for a
# token in the iteration after the token was generated.
required_num_kv_slots = final_seq_len - 1
# The allocated number of kv slots is the number of allocated blocks
# times the number of slots per block.
number_physical_blocks = len(
seq_group_metadata.block_tables[seq_id])
allocated_kv_slots = (number_physical_blocks *
self.model_runner.block_size)
if required_num_kv_slots > allocated_kv_slots:
request_id = seq_group_metadata.request_id
raise ValueError(
"The worker attempted to run "
f"{num_steps} times but found insufficient KV space for "
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
f"{required_num_kv_slots=}).")
def _raise_if_unsupported(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
"""MultiStepWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
raise NotImplementedError(
"MultiStepWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list):
raise NotImplementedError(
"MultiStepWorker does not support beam search.")
"""A GPU worker class.""" """A GPU worker class."""
import gc
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Tuple, Set, Optional
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
SchedulerConfig) ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_object_list) broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel) ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
from vllm.utils import is_hip
class Worker: class Worker:
...@@ -30,23 +35,33 @@ class Worker: ...@@ -30,23 +35,33 @@ class Worker:
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config, parallel_config, self.model_runner = ModelRunner(model_config,
scheduler_config, is_driver_worker) parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# self.init_cache_engine(). # self.init_cache_engine().
self.cache_config = None self.cache_config = None
...@@ -54,26 +69,30 @@ class Worker: ...@@ -54,26 +69,30 @@ class Worker:
self.cache_events = None self.cache_events = None
self.gpu_cache = None self.gpu_cache = None
def init_model(self) -> None: def init_model(self, cupy_port: Optional[int] = None) -> None:
# torch.distributed.all_reduce does not free the input tensor until if self.device_config.device.type == "cuda":
# the synchronization point. This causes the memory usage to grow # torch.distributed.all_reduce does not free the input tensor until
# as the number of all_reduce calls increases. This env var disables # the synchronization point. This causes the memory usage to grow
# this behavior. # as the number of all_reduce calls increases. This env var disables
# Related issue: # this behavior.
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 # Related issue:
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) # This env var set by Ray causes exceptions with graph building.
self.device = torch.device(f"cuda:{self.local_rank}") os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
torch.cuda.set_device(self.device) self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
_check_if_gpu_supports_dtype(self.model_config.dtype)
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
_init_distributed_environment(self.parallel_config, self.rank, init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method) cupy_port, self.distributed_init_method)
# Initialize the model. # Initialize the model.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
...@@ -86,6 +105,7 @@ class Worker: ...@@ -86,6 +105,7 @@ class Worker:
block_size: int, block_size: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
cpu_swap_space: int, cpu_swap_space: int,
cache_dtype: str,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model and returns the maximum """Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated. number of GPU and CPU cache blocks that can be allocated.
...@@ -107,16 +127,21 @@ class Worker: ...@@ -107,16 +127,21 @@ class Worker:
# profiled peak memory. # profiled peak memory.
torch.cuda.synchronize() torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory # NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size( cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config) block_size, cache_dtype, self.model_config, self.parallel_config)
num_gpu_blocks = int( num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) // (total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size) cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
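A numeric sketch of the block budget computed above (hypothetical values; real numbers come from torch.cuda.mem_get_info() and CacheEngine.get_cache_block_size):
# Hypothetical 80 GiB GPU, 0.9 utilization target, 30 GiB peak usage during profiling,
# and an ~8 MiB KV-cache block.
GiB = 1 << 30
total_gpu_memory = 80 * GiB
gpu_memory_utilization = 0.9
peak_memory = 30 * GiB
cache_block_size = 8 * (1 << 20)
cpu_swap_space = 4 * GiB
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)   # 5376 GPU blocks, 512 CPU blocks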
...@@ -175,20 +200,21 @@ class Worker: ...@@ -175,20 +200,21 @@ class Worker:
assert blocks_to_swap_in is not None assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None assert blocks_to_swap_out is not None
assert blocks_to_copy is not None assert blocks_to_copy is not None
block_swapping_info = [ data = {
blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy "num_seq_groups": num_seq_groups,
] "blocks_to_swap_in": blocks_to_swap_in,
broadcast_object_list([num_seq_groups] + block_swapping_info, "blocks_to_swap_out": blocks_to_swap_out,
src=0) "blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else: else:
# num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, data = broadcast_tensor_dict(src=0)
# blocks_to_copy (4 elements) num_seq_groups = data["num_seq_groups"]
recv_data = [None] * 4 blocks_to_swap_in = data["blocks_to_swap_in"]
broadcast_object_list(recv_data, src=0) blocks_to_swap_out = data["blocks_to_swap_out"]
num_seq_groups = recv_data[0] blocks_to_copy = data["blocks_to_copy"]
block_swapping_info = recv_data[1:]
self.cache_swap(*block_swapping_info) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if num_seq_groups == 0: if num_seq_groups == 0:
...@@ -198,10 +224,20 @@ class Worker: ...@@ -198,10 +224,20 @@ class Worker:
self.gpu_cache) self.gpu_cache)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _init_distributed_environment(
def init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
rank: int, rank: int,
cupy_port: Optional[int],
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
...@@ -224,10 +260,35 @@ def _init_distributed_environment( ...@@ -224,10 +260,35 @@ def _init_distributed_environment(
init_method=distributed_init_method, init_method=distributed_init_method,
) )
if cupy_utils.is_initialized():
cupy_world_size = cupy_utils.get_world_size()
if cupy_world_size != parallel_config.world_size:
raise RuntimeError(
"cupy.distributed is already initialized but the cupy world "
"size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).")
elif (parallel_config.world_size > 1 and cupy_port is not None
and not is_hip()):
# NOTE(woosuk): We don't initialize CuPy process group when world size
# is 1.
# TODO(woosuk): Support multi-node connection.
cupy_utils.init_process_group(
world_size=parallel_config.world_size,
rank=rank,
host="localhost",
port=cupy_port,
)
# A small all_reduce for warmup. # A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda()) torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(parallel_config.tensor_parallel_size, if cupy_utils.is_initialized():
parallel_config.pipeline_parallel_size) cupy_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# Initialize a custom fast all-reduce implementation.
if not parallel_config.disable_custom_all_reduce:
init_custom_ar()
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
......