Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import os import os
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
from array import array from array import array
from typing import Any, Type from typing import Any, Type
......
# SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import os import os
import sys import sys
......
# SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import os import os
from collections import defaultdict from collections import defaultdict
...@@ -127,13 +129,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -127,13 +129,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1 num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
and self.parallel_config.pipeline_parallel_size == 1):
# For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization
else:
# Otherwise, the ray workers are allocated with a full GPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources. # The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker. # It holds the resource for the driver worker.
...@@ -153,12 +149,29 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -153,12 +149,29 @@ class RayDistributedExecutor(DistributedExecutorBase):
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers. # Create the workers.
driver_ip = get_ip() bundle_indices: List[int]
rank = 0 if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
assert len(bundle_indices) == self.parallel_config.world_size, \
("VLLM_RAY_BUNDLE_INDICES must have the same size"
f" as the world size, but got {bundle_indices=} "
f"and {self.parallel_config.world_size=}")
assert len(set(bundle_indices)) == len(bundle_indices), \
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
f" but got {bundle_indices=}")
else:
# use the first N bundles that have GPU resources.
bundle_indices = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if bundle.get(current_platform.ray_device_key, 0):
bundle_indices.append(bundle_id)
bundle_indices = bundle_indices[:self.parallel_config.world_size]
worker_metadata: List[RayWorkerMetaData] = [] worker_metadata: List[RayWorkerMetaData] = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs): driver_ip = get_ip()
if not bundle.get(current_platform.ray_device_key, 0): for rank, bundle_id in enumerate(bundle_indices):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy( scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group, placement_group=placement_group,
placement_group_capture_child_tasks=True, placement_group_capture_child_tasks=True,
...@@ -185,7 +198,6 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -185,7 +198,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
rpc_rank=rank) rpc_rank=rank)
worker_metadata.append( worker_metadata.append(
RayWorkerMetaData(worker=worker, created_rank=rank)) RayWorkerMetaData(worker=worker, created_rank=rank))
rank += 1
worker_ips = ray.get([ worker_ips = ray.get([
each.worker.get_node_ip.remote() # type: ignore[attr-defined] each.worker.get_node_ip.remote() # type: ignore[attr-defined]
......
# SPDX-License-Identifier: Apache-2.0
import os import os
import time import time
from collections import defaultdict from collections import defaultdict
...@@ -212,7 +214,10 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): ...@@ -212,7 +214,10 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
logger.info( logger.info(
"Waiting for creating a placement group of specs for " "Waiting for creating a placement group of specs for "
"%d seconds. specs=%s. Check " "%d seconds. specs=%s. Check "
"`ray status` to see if you have enough resources.", "`ray status` to see if you have enough resources,"
" and make sure the IP addresses used by ray cluster"
" are the same as VLLM_HOST_IP environment variable"
" specified in each node if you are running on a multi-node.",
int(time.time() - s), placement_group_specs) int(time.time() - s), placement_group_specs)
try: try:
......
# SPDX-License-Identifier: Apache-2.0
import os import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
......
# SPDX-License-Identifier: Apache-2.0
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
......
# SPDX-License-Identifier: Apache-2.0
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonInputsAdapter, SingletonPrompt, SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Literal, Sequence, TypedDict, Union, cast, overload from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
from typing_extensions import TypeIs from typing_extensions import TypeIs
......
# SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
from typing import List, Mapping, Optional, Union from typing import List, Mapping, Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
import functools import functools
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
...@@ -29,6 +31,17 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) ...@@ -29,6 +31,17 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin) P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
class HashableDict(dict):
"""
A dictionary that can be hashed by lru_cache.
"""
# NOTE: pythonic dict is not hashable,
# we override on it directly for simplicity
def __hash__(self) -> int: # type: ignore[override]
return hash(frozenset(self.items()))
@dataclass(frozen=True) @dataclass(frozen=True)
class InputContext: class InputContext:
""" """
...@@ -102,6 +115,13 @@ class InputContext: ...@@ -102,6 +115,13 @@ class InputContext:
if isinstance(typ, type): if isinstance(typ, type):
merged_kwargs["processor_cls"] = typ merged_kwargs["processor_cls"] = typ
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
hf_processor = cached_get_processor( hf_processor = cached_get_processor(
self.model_config.model, self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
......
# SPDX-License-Identifier: Apache-2.0
"""Logging configuration for vLLM.""" """Logging configuration for vLLM."""
import datetime import datetime
import json import json
......
# SPDX-License-Identifier: Apache-2.0
from vllm.logging_utils.formatter import NewLineFormatter from vllm.logging_utils.formatter import NewLineFormatter
__all__ = [ __all__ = [
......
# SPDX-License-Identifier: Apache-2.0
import logging import logging
......
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Tuple, Union from typing import Callable, List, Tuple, Union
import torch import torch
......
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=unused-argument # pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
......
# SPDX-License-Identifier: Apache-2.0
# pylint: disable=unused-argument # pylint: disable=unused-argument
import math import math
from dataclasses import dataclass from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from typing import List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
......
# SPDX-License-Identifier: Apache-2.0
import copy import copy
import math import math
import os import os
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment