# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
import torch
import torch.nn as nn
from enum import IntEnum
from typing import Dict, List, Optional, Set, Tuple, Union
import warnings
import vllm.envs as envs
from vllm.attention import (AttentionMetadata, get_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.models.interfaces import (supports_lora, supports_vision)
from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available)
from vllm.worker.model_runner import ModelRunner, CUDAGraphRunner
from vllm.prompt_adapter.worker_manager import (LRUCacheWorkerPromptAdapterManager)
from .model_loader import get_model
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
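# Illustrative sketch (not from the original file): the batch type can be
# derived from the scheduled sequence groups, assuming each metadata object
# exposes vLLM's `is_prompt` flag:
#     is_prompt = [m.is_prompt for m in seq_group_metadata_list]
#     if all(is_prompt):
#         batch_type = BatchType.PREFILL
#     elif not any(is_prompt):
#         batch_type = BatchType.DECODE
#     else:
#         batch_type = BatchType.MIXED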
class ModelRunner(ModelRunner):
def __init__(
self,
model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config,
kv_cache_dtype,
is_driver_worker=True, # a hack
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
return_hidden_states=return_hidden_states)
# NOTE(sgm): add for verl
self.model = model # this will be replaced by get_model()
# NOTE(sgm): initialize model using the actor model
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with CudaMemoryProfiler() as m:
self.model = get_model(actor_model=self.model,
model_config=self.model_config,
device_config=self.device_config,
lora_config=self.lora_config,
load_config=self.load_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
multimodal_config=self.multimodal_config,
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_vision(self.model), "To be tested: vision language model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=self.model.config.max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (self.prompt_adapter_manager.create_prompt_adapter_manager(self.model))
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
f"model {self.model.__class__} does not support loading scaling factors.")
else:
logger.warning("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
self.model = torch.compile(self.model, fullgraph=True, backend="eager")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import os
import torch
import torch.distributed
from typing import Optional
import vllm.distributed.parallel_state as ps
from vllm.distributed.parallel_state import get_pp_group, get_world_group, init_distributed_environment, init_model_parallel_group
import vllm.envs as envs
from vllm.logger import init_logger
from torch.distributed.device_mesh import init_device_mesh
logger = init_logger(__name__)
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
"""
# Device mesh for using DTensor
_DEVICE_MESH = None
# Tensor model parallel group that the current rank belongs to.
_TP = None
# Pipeline model parallel group that the current rank belongs to.
_PP = None
# This function initializes the parallel groups when using HybridEngine
def initialize_parallel_state(
distributed_init_method: str = "env://",
backend: str = "nccl",
tensor_model_parallel_size: int = 1,
num_tp_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modified for verl; env vars will be set by TORCHRUN.
rank = int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)
if torch.distributed.get_world_size() > 1:
# NOTE: build a separate inference group with infer tp & micro dp
initialize_model_parallel_for_vllm(tensor_model_parallel_size=tensor_model_parallel_size,
num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp)
else:
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
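# Usage sketch (assumes a TORCHRUN launch so RANK/LOCAL_RANK/WORLD_SIZE are set,
# and that the Megatron tp+dp+pp world already exists). For example, with
# training tp=4 and inference tp=2:
#     initialize_parallel_state(
#         backend="nccl",
#         tensor_model_parallel_size=2, # inference tp
#         num_tp_per_train_tp=2, # train_tp // infer_tp = 4 // 2
#     )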
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
return
assert (get_tensor_model_parallel_world_size() == tensor_model_parallel_size), (
"tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
pp_world_size = get_pp_group().world_size
assert (pp_world_size == pipeline_model_parallel_size), (
"pipeline parallel group already initialized, but of unexpected size: "
f"{pp_world_size=} vs. "
f"{pipeline_model_parallel_size=}")
# TODO(sgm): deviates from v0.5.4; pipeline parallelism is not checked here
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (ps._TP is not None)
# and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def initialize_model_parallel_for_vllm(tensor_model_parallel_size: int,
num_tensor_model_parallel_groups_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1) -> None:
from torch.distributed import new_group
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
assert isinstance(tensor_model_parallel_size, int)
# assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group
# assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group
# Build the tensor model-parallel groups.
assert ps._TP is None, ("tensor model parallel group is already initialized")
global _TP
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = torch.distributed.get_backend()
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
if num_tensor_model_parallel_groups_per_train_tp == 1:
# if tensor_model_parallel_size == train_tensor_parallel_size:
# using the same tp group as Megatron/vllm
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
use_custom_allreduce=False, # TODO: check why True does not work in the Ray trainer
use_message_queue_broadcaster=True)
ps._TP = _TP
# _MICRO_DATA_PARALLEL_GROUP has moved to the hybrid engine
else:
# initialize a micro_dp group and a tp group
# assume training tp=4, infer tp=2, then, weight is partitioned as
# [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
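# Worked example (illustrative): with world_size=4, tensor_model_parallel_size=2
# and num_tensor_model_parallel_groups_per_train_tp=2, train_tp=4 and the loop
# below builds strided inference tp groups: j=0 -> [0, 2], j=1 -> [1, 3],
# i.e. group_ranks == [[0, 2], [1, 3]].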
# Build the inference tp groups
# train_tp = train_tensor_parallel_size
train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size
# num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
start = train_tp * i
end = train_tp * (i + 1)
for j in range(num_tensor_model_parallel_groups_per_train_tp):
ranks = [base_rank + j for base_rank in range(start, end, num_tensor_model_parallel_groups_per_train_tp)]
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
use_custom_allreduce=False, # TODO: check why True does not work in the Ray trainer
use_message_queue_broadcaster=True)
ps._TP = _TP
# Build the pipeline model-parallel groups.
# global _PIPELINE_MODEL_PARALLEL_GROUP
# global _PIPELINE_GLOBAL_RANKS
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
# TODO: init using device mesh (hybrid engine not supported yet)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size)
global _PP
assert _PP is None, ("pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
NOTE: This method is a hack on the open-sourced version: it drops the
assertion that world_size == tp * pp.
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)
# NOTE(sgm) we don't assert world_size == tp * pp
# DP is not managed by vllm but by the verl WorkerGroup
# if (world_size !=
# tensor_model_parallel_size * pipeline_model_parallel_size):
# raise RuntimeError(
# f"world_size ({world_size}) is not equal to "
# f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
# f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size)
rank = torch.distributed.get_rank()
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False, # TODO: check why True does not work in the Ray trainer
use_message_queue_broadcaster=True)
ps._TP = _TP
# TODO: init using device mesh (hybrid engine not supported yet)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size)
global _PP
assert _PP is None, ("pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
"""
Device mesh utilities
"""
def get_device_mesh():
assert _DEVICE_MESH is not None, ("device mesh is not initialized")
return _DEVICE_MESH
"""
Tensor model parallel utilities
"""
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP is not None, ("tensor model parallel group is not initialized")
return _TP.device_group
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
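# Example (illustrative): with a tensor parallel world size of 4, global ranks
# 4..7 all map to src rank 4, e.g. (6 // 4) * 4 == 4.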
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
import os
import socket
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase, ExecutorAsyncBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, ExecuteModelRequest
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig)
from .config import ModelConfig, LoadConfig
logger = init_logger(__name__)
class SPMDGPUExecutor(ExecutorBase):
"""SPMD-based multi-GPU executor implementations."""
def __init__(
self,
model, # pytorch model itself or its parameter dict
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
distributed_init_method = initialize_cluster(parallel_config)
self._init_executor(model, distributed_init_method)
# TODO(sgm): verl does not support speculative decoding yet
def _init_executor(self, model, distributed_init_method) -> None:
assert (not self.speculative_config), "Speculative decoding not yet supported for multi-GPU backend."
# Create the parallel worker for each GPU.
self._init_workers_sp(model, distributed_init_method)
def _init_workers_sp(self, model, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker # pylint: disable=import-outside-toplevel
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
print(f'local rank {local_rank}')
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
self.worker = Worker(
model,
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
speculative_config=None,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=True,
model_runner_cls=None, # use the default one
)
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
self.worker.init_device()
self.worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on the underlying worker;
the min across ranks is taken inside the worker via an all-reduce,
guaranteeing that the selected cache sizes are compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self.worker.determine_num_available_blocks()
# NOTE(shengguangming): We don't use a shared centralized controller; each
# process has its own scheduler
num_gpu_blocks = num_blocks[0]
num_cpu_blocks = num_blocks[1]
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers.
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if torch.distributed.get_rank() == 0:
print(
f'before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
)
self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)
if torch.distributed.get_rank() == 0:
print(
f'after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB'
)
# NOTE(sgm): This will not profile & capture the model (CUDA graph) when rebuilding the KV cache
def init_cache_engine(self) -> None:
self.worker._init_cache_engine()
def free_cache_engine(self) -> None:
self.worker.free_cache_engine()
def execute_model(self, execute_model_req) -> List[SamplerOutput]:
all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
# NOTE(sgm):
# Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs
# In vllm with ray, only the driver worker returns the sampling results.
return all_outputs
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request=lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id=lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def check_health(self) -> None:
# SPMDExecutor will always be healthy as long as
# it's running.
return
# NOTE(sgm): added for verl to pass the abstract-class test; not used
from vllm.prompt_adapter.request import PromptAdapterRequest
def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.worker.add_prompt_adapter(prompt_adapter_request)
def list_prompt_adapters(self) -> Set[int]:
return self.worker.list_prompt_adapters()
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.pin_lora(lora_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.worker.pin_prompt_adapter(prompt_adapter_id)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.worker.remove_prompt_adapter(prompt_adapter_id)
# NOTE(sgm): add for verl
def offload_model_weights(self) -> None:
self.worker.offload_model_weights()
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> str:
"""Initialize the distributed cluster (no Ray in verl's SPMD mode).
Args:
parallel_config: The configurations for parallel execution.
Returns:
The `distributed_init_method` is the address for initializing the
distributed backend.
"""
# Initialize cluster locally.
port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
# distributed_init_method = f"tcp://localhost:{port}"
distributed_init_method = 'env://'
return distributed_init_method
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
# TODO(sgm): async executor not implemented yet
class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
from typing import List, Optional, Tuple, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast)
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import *
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int]):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = tokenizer
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
# TODO(sgm): the lora tokenizer is also passed, but may be different
tokenizer = self.tokenizer
# tokenizer = (get_lora_tokenizer(
# lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
# FIXME(sgm): for simplicity, we assign the special token here
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id
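# Usage sketch (illustrative; the model name is a placeholder):
#     tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#     group = TokenizerGroup(tokenizer, enable_lora=False, max_num_seqs=256,
#                            max_input_length=None)
#     token_ids = group.encode("Hello world") # no LoRA tokenizer involved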
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import os
import gc
from typing import Dict, List, Tuple, Optional, Union, Type
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig)
from vllm.model_executor import set_random_seed
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput)
from vllm.worker.cache_engine import CacheEngine
# TODO(sgm): check why vllm has a similar file in vllm.model_executor.parallel_utils.parallel_state
from vllm.distributed import (init_distributed_environment, set_custom_all_reduce, get_tensor_model_parallel_group)
from vllm.worker.worker_base import WorkerInput
from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase
from .model_runner import ModelRunner
from .megatron_weight_loaders import load_megatron_weights
from .hf_weight_loader import load_hf_weights
from .dtensor_weight_loaders import load_dtensor_weights
from .parallel_state import (ensure_model_parallel_initialized)
from .config import ModelConfig, LoadConfig, LoadFormat
class Worker(Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
# self.model = model # will be replaced in the init_model
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker # TODO: we don't need driver
# if parallel_config and is_driver_worker:
# assert rank % parallel_config.tensor_parallel_size == 0, \
# "Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.multimodal_config = multimodal_config
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator"]) \
else {"return_hidden_states": True}
# TODO(sgm): set correct model runner class
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model, # [VERL]: add for verl
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine] = None
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
# NOTE(sgm): [VERL] For offloading inference engine params
self.cpu_model = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# NOTE(sgm): Modified for verl; env vars will be set by TORCHRUN.
self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
self.parallel_config.world_size = world_size
_check_if_gpu_supports_dtype(self.model_config.dtype)
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# self.model = get_model(actor_model=self.model, model_config=self.model_config)
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
# NOTE(sgm) [VERL] use the remaining memory
num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size)
# num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
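# Worked example (hypothetical numbers): with 40 GB free after profiling,
# gpu_memory_utilization=0.9 and a 2 MB cache block,
# num_gpu_blocks = int(40e9 * 0.9 // 2e6) = 18000.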
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
# NOTE(sgm): Added for [VERL]: synchronize the number of blocks across all ranks
num_gpu_blocks = torch.tensor([num_gpu_blocks], device='cuda')
num_cpu_blocks = torch.tensor([num_cpu_blocks], device='cuda')
torch.distributed.all_reduce(num_gpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
torch.distributed.all_reduce(num_cpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
num_gpu_blocks = num_gpu_blocks.item()
num_cpu_blocks = num_cpu_blocks.item()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
def _init_cache_engine(self):
if self.cache_engine is None and self.gpu_cache is None:
super()._init_cache_engine()
def free_cache_engine(self):
# ensure `enforce_eager=True`
self.cache_engine = None
self.gpu_cache = None
# NOTE(sgm): [VERL]: adapted from _execute_model_spmd()
def execute_model(self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
# verl.worker.workerbase.WorkerBase
# swap cache
super().execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None,
intermediate_tensors)
# assume the input is .state_dict()
def sync_model_weights(self, actor_weights: Dict, load_format: str):
if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]:
load_megatron_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.HF:
# full model state dict without sharding
load_hf_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.DTENSOR:
load_dtensor_weights(actor_weights, self.model_runner.model)
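# Usage sketch (illustrative): pushing updated actor weights into the rollout
# engine, assuming the actor exposes a full (unsharded) state dict on this rank:
#     worker.sync_model_weights(actor_model.state_dict(), load_format=LoadFormat.HF)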
def offload_model_weights(self) -> None:
if self.cpu_model is None:
self.cpu_model = {}
for name, params in self.model_runner.model.named_parameters():
self.cpu_model[name] = torch.empty_like(params, device='cpu')
params.data = self.cpu_model[name]
else:
for name, params in self.model_runner.model.named_parameters():
params.data = self.cpu_model[name]
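# Design note (assumption): `torch.empty_like(..., device='cpu')` allocates
# uninitialized CPU buffers, so the GPU weight values are dropped rather than
# copied; fresh weights are expected to be pushed via sync_model_weights()
# before the next generation step.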
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = "env://",
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
# NOTE(sgm): using tcp://localhost:xxxx will hang in the HF setting without Megatron
init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank)
ensure_model_parallel_initialized(tensor_model_parallel_size=parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_parallel_size)
# TODO(sgm): check whether this is needed
# if pynccl_utils.is_initialized():
# pynccl_world_size = pynccl_utils.get_world_size()
# if pynccl_world_size != parallel_config.world_size:
# raise RuntimeError(
# "pynccl is already initialized but the pynccl world "
# "size does not match parallel_config.world_size "
# f"({pynccl_world_size} vs. {parallel_config.world_size}).")
# elif parallel_config.world_size > 1:
# # NOTE(woosuk): We don't initialize pynccl process group when world size
# # is 1.
# # NOTE(kaichao): By default, pynccl is initialized for tp group.
# pynccl_utils.init_process_group(
# group=get_tensor_model_parallel_cpu_group())
# # Initialize a custom fast all-reduce implementation.
# if not parallel_config.disable_custom_all_reduce:
# init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
# if pynccl_utils.is_initialized():
# pynccl_utils.all_reduce(torch.zeros(1).cuda())
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
import os
from dataclasses import dataclass
from transformers import PretrainedConfig
from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs
from .config import LoadConfig, ModelConfig
@dataclass
class EngineArgs(EngineArgs):
model_hf_config: PretrainedConfig = None # for verl
def __post_init__(self):
pass
def create_model_config(self) -> ModelConfig:
return ModelConfig(
hf_config=self.model_hf_config,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
)
def create_load_config(self) -> LoadConfig:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)
def create_engine_config(self) -> EngineConfig:
engine_config = super().create_engine_config()
# NOTE[VERL]: Use the world_size set by torchrun
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
engine_config.parallel_config.world_size = world_size
return engine_config
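# Usage sketch (illustrative; the AutoConfig call stands in for however the
# trainer obtains the actor's HF config; requires WORLD_SIZE set by torchrun):
#     from transformers import AutoConfig
#     hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
#     engine_args = EngineArgs(model_hf_config=hf_config, tensor_parallel_size=2, dtype="bfloat16")
#     engine_config = engine_args.create_engine_config()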
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py
import enum
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Union
from transformers import PretrainedConfig
# Add for verl
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip
if TYPE_CHECKING:
from vllm.model_executor.model_loader.loader import BaseModelLoader
logger = init_logger(__name__)
class LoadFormat(str, enum.Enum):
AUTO = "auto"
MEGATRON = "megatron"
HF = "hf"
DTENSOR = "dtensor"
DUMMY_HF = "dummy_hf"
DUMMY_MEGATRON = "dummy_megatron"
DUMMY_DTENSOR = "dummy_dtensor"
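# Example (illustrative): LoadFormat is a str-valued enum, so string spellings
# are accepted and normalized:
#     LoadFormat("dtensor") is LoadFormat.DTENSOR # True
#     LoadConfig(load_format="megatron") # coerced to LoadFormat.MEGATRON in __post_init__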
class ModelConfig(ModelConfig):
def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs)
self.hf_config = hf_config
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
else:
self.ignore_patterns = ["original/**/*"]
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return
load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)
rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
]
raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch.nn as nn
from torch.distributed._tensor import DTensor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter
def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
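# Example (illustrative): a checkpoint entry
#     "model.layers.0.self_attn.q_proj.weight"
# is loaded into the fused vLLM parameter
#     "model.layers.0.self_attn.qkv_proj.weight" with shard_id "q".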
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
stacked_name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if stacked_name.endswith(".bias") and stacked_name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[stacked_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
from vllm.model_executor.layers.fused_moe import FusedMoE
def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=vllm_model.config.n_routed_experts,
)
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks loading.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
local_loaded_weight.to(dtype=param.dtype),
weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, vllm_model):
continue
param = params_dict[name]
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, local_loaded_weight.to(dtype=param.dtype))
def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
    # TODO: DTensor weight loading for GPT-2 is not implemented yet.
    pass
def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
param_name = _process_parameter_names(name=param_name)
if parallelize_plan is not None:
        assert param_name in parallelize_plan.keys(), (
            f"param name: {param_name} not in parallelize_plan: {parallelize_plan.keys()}")
placement = parallelize_plan[param_name]
local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh,
placements=placement).to_local()
else:
local_loaded_weights = loaded_weights.full_tensor()
return local_loaded_weights
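# Illustrative usage sketch (assumes `w` is a torch.distributed DTensor sharded
# over an initialized device mesh):
#   full_w = redistribute_dtensor("model.layers.0.mlp.gate_proj.weight", w)
# Without a parallelize_plan, the DTensor is materialized in full on every rank
# via full_tensor(); with a plan, it is redistributed to the requested
# placements and converted to a local shard via to_local().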
def _process_parameter_names(name):
# Remove '.weight' if it exists at the end of the string
if name.endswith(".weight"):
name = name[:-7]
# Remove 'model.layers.x.' or 'model.' prefix
if "model.layers" in name:
parts = name.split(".")
# Reconstruct the string without 'model.layers.x.'
name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
elif name.startswith("model."):
name = name[6:] # Remove 'model.'
return name
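# A minimal sketch (illustrative only, never called by the loaders) of the
# name normalization performed by _process_parameter_names:
def _demo_process_parameter_names() -> None:
    assert _process_parameter_names("model.layers.0.self_attn.qkv_proj.weight") == "self_attn.qkv_proj"
    assert _process_parameter_names("model.embed_tokens.weight") == "embed_tokens"
    assert _process_parameter_names("lm_head.weight") == "lm_head"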
__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
"GPT2LMHeadModel": gpt2_dtensor_weight_loader,
"LlamaForCausalLM": llama_dtensor_weight_loader,
"LLaMAForCausalLM": llama_dtensor_weight_loader,
"MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM
"InternLMForCausalLM": llama_dtensor_weight_loader,
"AquilaModel": llama_dtensor_weight_loader,
"AquilaForCausalLM": llama_dtensor_weight_loader,
"Phi3ForCausalLM": llama_dtensor_weight_loader,
"GemmaForCausalLM": gemma_dtensor_weight_loader,
"Gemma2ForCausalLM": gemma_dtensor_weight_loader,
"GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights,
"Starcoder2ForCausalLM": starcoder2_dtensor_load_weights,
"Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
"DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
"Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}
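# To support a new architecture, one would (sketch) implement a loader with the
# same (actor_weights, vllm_model) signature and register it here, e.g.:
#   __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__["MyModelForCausalLM"] = my_dtensor_weight_loader
# ("MyModelForCausalLM" and my_dtensor_weight_loader are hypothetical names.)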
# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vllm model to cpu
    # after init, so we need to move it back to gpu after syncing model weights in the first iteration.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}")
# NOTE(sgm): we use a per-parameter weight loader in each vLLM submodule, so nothing needs updating here
def update_dtensor_weight_loader():
pass
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch.nn as nn
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
def update_hf_weight_loader():
print("no hf weight loader need to be updated")
return
def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
assert isinstance(actor_weights, Dict)
with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
del actor_weights["lm_head.weight"]
vllm_model.load_weights(actor_weights.items())
for _, module in vllm_model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
vllm_model = vllm_model.cuda()
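# Illustrative call (assumes `actor` is a HuggingFace PreTrainedModel and
# `vllm_model` is the matching initialized vLLM model):
#   load_hf_weights(actor_weights=dict(actor.named_parameters()), vllm_model=vllm_model)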
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
from vllm import LLM
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.utils import Counter
from .arg_utils import EngineArgs
from .llm_engine_sp import LLMEngine
class LLM(LLM):
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: A HuggingFace Transformers model instance.
tokenizer: A HuggingFace Transformers tokenizer instance.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer],
model_hf_config: PretrainedConfig,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
skip_tokenizer_init: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_format="auto",
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type")
if any(k in kwargs for k in removed_vision_keys):
raise TypeError("There is no need to pass vision-related arguments anymore.")
engine_args = EngineArgs(
model_hf_config=model_hf_config,
# tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
load_format=load_format,
**kwargs,
)
tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
        self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)  # TODO: check usage context
self.request_counter = Counter()
def init_cache_engine(self):
self.llm_engine.init_cache_engine()
def free_cache_engine(self):
self.llm_engine.free_cache_engine()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
outputs = super()._run_engine(use_tqdm=use_tqdm)
return self._post_process_outputs(outputs)
# # NOTE(shengguangming): add for verl
# # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding.
# def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]:
# # remove the left padding in the prompt token_id
# pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
# non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
# token_ids = prompt_token_ids[non_pad_index:].tolist()
# return token_ids
# NOTE(shengguangming): add for verl
def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]:
output_token_ids = []
logprobs = []
for request_output in request_outputs: # List[RequestOutput]
outputs = request_output.outputs
for output in outputs: # List[CompletionOutput], usually len == 1
output_token_ids.append(torch.tensor(output.token_ids))
                # TODO(shengguangming): can be optimized by rewriting the Sampler._get_logprobs() logic
logprobs_dicts = output.logprobs
if logprobs_dicts is not None:
logprob = []
                    for logprobs_dict, token_id in zip(logprobs_dicts, output.token_ids):
                        logprob.append(logprobs_dict[token_id].logprob)
logprobs.append(torch.tensor(logprob))
pad_token_id = (self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None
else self.llm_engine.tokenizer.eos_token_id)
output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id)
if len(logprobs) > 0:
logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id)
return output_token_ids, logprobs
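    # Shape sketch: the return value is (output_token_ids, logprobs), both
    # right-padded to [num_completions, max_response_len]; logprobs stays an
    # empty list when SamplingParams did not request logprobs.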
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.llm_engine.offload_model_weights()
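# Minimal usage sketch (hedged; `hf_model`, `hf_tokenizer`, and `hf_config` are
# assumed to be a HuggingFace model, tokenizer, and config built elsewhere):
#   llm = LLM(model=hf_model, tokenizer=hf_tokenizer, model_hf_config=hf_config,
#             tensor_parallel_size=1, dtype="bfloat16", load_format="hf")
#   llm.sync_model_weights(actor_weights=hf_model.state_dict(), load_format="hf")
#   llm.init_cache_engine()   # allocate the KV cache before generation
#   ... generate ...
#   llm.free_cache_engine()   # release the KV cache between rollout iterations
#   llm.offload_model_weights()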
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
from functools import partial
from typing import Callable, Dict, Optional, Type, Union
import torch
import torch.nn as nn
from vllm.config import (
CacheConfig,
DecodingConfig,
DeviceConfig,
EngineConfig,
LoadConfig,
LoRAConfig,
ModelConfig,
ObservabilityConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
SpeculativeConfig,
)
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.sequence import Sequence
from vllm.tracing import init_tracer
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils import Counter, weak_bind
from vllm.version import __version__ as VLLM_VERSION
from .arg_utils import EngineArgs
from .config import LoadConfig, ModelConfig
from .tokenizer import TokenizerGroup
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
class LLMEngine(LLMEngine):
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The :class:`~vllm.LLM` class wraps this class for offline batched inference
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
The config arguments are derived from :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
"""
def __init__(
self,
# NOTE(sgm): first two arguments are added for verl
model: Union[nn.Module, Dict], # model itself or its parameter dict
tokenizer: nn.Module,
# NOTE(sgm): vllm original arguments
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig()
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer(tokenizer)
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, "tokenizer_group cannot be None, " "make sure skip_tokenizer_init is False"
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(model_config)
self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(model_config)
self.model_executor = executor_class(
model=model, # add for spmd_gpu_executor
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import get_architecture_class_name
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype": str(model_config.dtype),
"tensor_parallel_size": parallel_config.tensor_parallel_size,
"block_size": cache_config.block_size,
"gpu_memory_utilization": cache_config.gpu_memory_utilization,
# Quantization
"quantization": model_config.quantization,
"kv_cache_dtype": str(cache_config.cache_dtype),
# Feature flags
"enable_lora": bool(lora_config),
"enable_prompt_adapter": bool(prompt_adapter_config),
"enable_prefix_caching": cache_config.enable_prefix_caching,
"enforce_eager": model_config.enforce_eager,
"disable_custom_all_reduce": parallel_config.disable_custom_all_reduce,
},
)
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
if model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
partial(process_model_outputs, ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self.process_request_outputs_callback: Optional[Callable] = None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(
scheduler_config,
cache_config,
lora_config,
parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id] if model_config.use_async_output_proc else None,
) for v_id in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging.
if self.log_stats:
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger
self.stat_loggers = {
"logging":
LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len,
),
}
self.stat_loggers["prometheus"].info("cache_config", self.cache_config)
self.tracer = None
if self.observability_config.otlp_traces_endpoint:
self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
)
    # TODO(sgm): added for verl, but we may not need a tokenizer in Rollout
def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs):
init_kwargs = dict(enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None)
init_kwargs.update(tokenizer_init_kwargs)
return TokenizerGroup(tokenizer, **init_kwargs)
def init_cache_engine(self):
# TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache
# Re-capture CUDAGraph would be time-consuming
self.model_executor.init_cache_engine()
def free_cache_engine(self):
self.model_executor.free_cache_engine()
# NOTE(sgm): currently, we only support GPU executor
# The GPUExecutor remove the Ray dependency
@classmethod
def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]:
distributed_executor_backend = engine_config.parallel_config.distributed_executor_backend
        # Initialize the cluster and specify the executor class.
        assert (engine_config.device_config.device_type == "cuda"
               ), "Currently, the vllm in verl only supports running on GPU"
if engine_config.parallel_config.world_size == 1:
engine_config.load_config.load_format = "dummy_hf"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
return executor_class
@classmethod
def from_engine_args(
cls,
model,
tokenizer,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config)
# Initialize the cluster and specify the executor class.
        assert (engine_config.device_config.device_type == "cuda"
               ), "Currently, the vllm in verl only supports running on GPU"
from .spmd_gpu_executor import SPMDGPUExecutor
executor_class = SPMDGPUExecutor
# Create the LLM engine.
engine = cls(
model,
tokenizer,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def offload_model_weights(self) -> None:
self.model_executor.offload_model_weights()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict
import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
from vllm.model_executor.models import ModelRegistry
# NOTE(shengguangming): replace the original weight loader function in the class
def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Parallel Linear weight loader."""
    assert param.size() == loaded_weight.size(), (
        "the parameter size does not match the loaded weight size, "
        "param size: {}, loaded_weight size: {}".format(param.size(), loaded_weight.size()))
    assert param.data.dtype == loaded_weight.data.dtype, (
        "to share weights, the data types must also match")
param.data = loaded_weight.data
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
    assert param.data.dtype == loaded_weight.data.dtype, (
        "to share weights, the data types must also match")
param.data = loaded_weight.data
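# Note: both loaders above rebind param.data to the loaded tensor rather than
# copying into it, so after a sync the vLLM parameter shares storage with the
# actor weight; this is why sizes and dtypes are asserted to match exactly.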
def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
for name, loaded_weight in actor_weights.items():
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
# TODO: check megatron
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def qwen2_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith(".bias") and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
(
"input_layernorm",
"input_layernorm",
),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
name = _replace_name(name, params_mapping)
if name.endswith(".bias") and name not in params_dict:
continue
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if "layers" in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace("decoder", "model")
megatron_name_list = megatron_name.split(".")
if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = ".".join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = ".".join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
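# Illustrative mapping performed by _replace_name with a table like the one in
# llama_megatron_core_te_weight_loader above:
#   "decoder.layers.0.self_attention.linear_proj.weight"
#       -> "model.layers.0.self_attn.o_proj.weight"
#   "decoder.layers.0.self_attention.linear_qkv.layer_norm_weight"
#       -> "model.layers.0.input_layernorm.weight"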
def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
# TODO: need to implement a general way to deal with prefix
params_dict = dict(vllm_model.named_parameters())
for name, loaded_weight in actor_weights.items():
if "rotary_emb.inv_freq" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
def megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> None:
params_mapping = [
# (megatron core gpt model name, vllm model name)
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
]
# NOTE(shengguangming): the megatron llama may have this prefix
params_dict = dict(vllm_model.named_parameters())
for original_name, loaded_weight in actor_weights.items():
name = _replace_name(original_name, params_mapping)
        if not name or (name.endswith(".bias") and name not in params_dict):
continue
if "rotary_emb.inv_freq" in name:
continue
if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
continue
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = {
ColumnParallelLinear: parallel_weight_loader,
MergedColumnParallelLinear: parallel_weight_loader,
QKVParallelLinear: parallel_weight_loader,
RowParallelLinear: parallel_weight_loader,
VocabParallelEmbedding: parallel_weight_loader,
ParallelLMHead: parallel_weight_loader,
# "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights
# "default_weight_loader": default_weight_loader
}
__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = {
"GPT2LMHeadModel": gpt2_weight_loader,
"LlamaForCausalLM": megatron_core_te_weight_loader, # use te backend for open-source megatron
"LLaMAForCausalLM": megatron_core_te_weight_loader,
"MistralForCausalLM": mistral_megatron_weight_loader,
    "Qwen2ForCausalLM": megatron_core_te_weight_loader,
}
# the actor model is .state_dict()
# Load megatron weights
def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vllm model to cpu
    # after init, so we need to move it back to gpu after syncing model weights in the first iteration.
vllm_model = vllm_model.cuda()
def _get_model_weight_loader(arch: str):
if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__:
return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch]
raise ValueError(f"Model architectures {arch} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def update_megatron_weight_loader():
for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items():
layer_class.weight_loader = weight_loader
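# After update_megatron_weight_loader() runs, each registered parallel layer
# class exposes parallel_weight_loader as its weight_loader method, so
# parameters constructed afterwards resolve getattr(param, "weight_loader",
# default_weight_loader) to the Megatron-aware loader above.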
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models
"""Utilities for selecting and loading models."""
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.model_executor.model_loader import BaseModelLoader
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from .config import LoadConfig, LoadFormat, ModelConfig
from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader
from .hf_weight_loader import update_hf_weight_loader
from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader
def get_model(
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
load_config: LoadConfig,
device_config: DeviceConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig = None,
) -> nn.Module:
loader = get_model_loader(load_config)
if load_config.load_format.startswith("dummy"):
return loader.load_model(
model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config,
)
else:
return loader.load_model(
actor_model=actor_model,
model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config,
)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)
if load_config.load_format == LoadFormat.AUTO:
update_megatron_weight_loader()
return MegatronLoader(load_config)
# NOTE(sgm): change the weight_loader function in runtime
if load_config.load_format == LoadFormat.MEGATRON:
update_megatron_weight_loader()
return MegatronLoader(load_config)
if load_config.load_format == LoadFormat.HF:
update_hf_weight_loader()
return HFLoader(load_config)
if load_config.load_format == LoadFormat.DTENSOR:
update_dtensor_weight_loader()
return DTensorLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_HF:
update_hf_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_MEGATRON:
update_megatron_weight_loader()
return DummyModelLoader(load_config)
if load_config.load_format == LoadFormat.DUMMY_DTENSOR:
update_dtensor_weight_loader()
return DummyModelLoader(load_config)
raise ValueError("load format not supported in verl: {}, only support {} and {}".format(
load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF))
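# Dispatch summary (mirrors the branches above):
#   AUTO / MEGATRON -> MegatronLoader (megatron per-layer loaders patched in)
#   HF              -> HFLoader
#   DTENSOR         -> DTensorLoader
#   DUMMY_*         -> DummyModelLoader (real weights are synced in later)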
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
                # NOTE(woosuk): For accurate performance evaluation, we would
                # assign random values to the weights; verl skips this because
                # real weights are synced in from the actor afterwards.
                # initialize_dummy_weights(model)
return model.eval()
class MegatronLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        # NOTE(shengguangming): load the weights from the actor model
        pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(
self,
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_megatron_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
        # NOTE(sgm): some weights already point to gpu, but we still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class HFLoader(BaseModelLoader):
"""Model loader that can load the model weights from model's full params."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
if isinstance(actor_model, Dict):
return actor_model.items()
elif isinstance(actor_model, nn.Module):
return dict(actor_model.named_parameters()).items()
else:
raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}")
def load_model(
self,
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
# with torch.device(device_config.device):
# NOTE(sgm): init the model in cpu
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
model.load_weights(self._get_weights_iterator(actor_model))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
        # NOTE(sgm): some weights already point to gpu, but we still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
class DTensorLoader(BaseModelLoader):
"""Model loader that can load the model weights from partitioned megatron model."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
    def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]):
        # NOTE(shengguangming): load the weights from the actor model
        pass
# if isinstance(actor_model, nn.Module):
# load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model)
# else:
# load_weights(actor_weights=actor_model, vllm_model=model)
# return actor_model
def load_model(
self,
actor_model: Union[PreTrainedModel, Dict],
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config)
# TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm
if isinstance(actor_model, nn.Module):
load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)),
vllm_model=model)
else:
load_dtensor_weights(actor_weights=actor_model, vllm_model=model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
        # NOTE(sgm): some weights already point to gpu, but we still need this.
model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage
return model.eval()
# FIXME(sgm): hack the _get_logits function in vllm v0.4.2.
# Because upstream vLLM uses Ray, the _get_logits result only needs to be
# returned to the driver node, so a plain gather is enough. We use SPMD instead
# of a central scheduler, so an all_gather is required (aligned with v0.2.6).
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
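# Sketch of the distinction: with tensor parallelism each rank holds only a
# vocab shard of the logits. Under a Ray driver, gather-to-driver suffices
# because only the driver samples; under SPMD every rank samples, so every
# rank needs the full vocab dimension, hence the all_gather above.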
from vllm.model_executor.layers.logits_processor import LogitsProcessor
def logitsprocessor_init(
self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: float = 1.0,
logits_as_input: bool = False,
soft_cap: Optional[float] = None,
) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
"""
super(LogitsProcessor, self).__init__()
self.scale = scale
self.vocab_size = vocab_size
# Whether the input is logits (default is hidden states).
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = False
LogitsProcessor.__init__ = logitsprocessor_init # use all_gather
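# NOTE: this monkey-patch must run before any vLLM model is constructed, so
# that every LogitsProcessor instance is created with use_gather=False and
# takes the all_gather path inside vLLM's own _get_logits.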
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py
import warnings
from enum import IntEnum
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.config import (
CacheConfig,
DeviceConfig,
LoadConfig,
LoRAConfig,
ModelConfig,
ObservabilityConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
)
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models.interfaces import supports_lora, supports_multimodal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager
from vllm.utils import DeviceMemoryProfiler, is_hip, supports_dynamo
from vllm.worker.model_runner import ModelRunner
from .config import LoadConfig, ModelConfig
from .model_loader import get_model
logger = init_logger(__name__)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
class ModelRunner(ModelRunner):
def __init__(
self,
model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
super().__init__(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config,
kv_cache_dtype,
is_driver_worker=True, # a hack
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
input_registry=input_registry,
mm_registry=mm_registry,
)
# NOTE(sgm): add for verl
self.model = model # this will be replaced by get_model()
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
self.model = get_model(
self.model,
model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30))
if self.lora_config:
assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings
# of VLMs and LLMs.
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = self.model.config.max_position_embeddings
else:
max_pos_embeddings = self.model.config.text_config.max_position_embeddings
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.device,
self.prompt_adapter_config,
)
self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2,
)
self.model.load_kv_cache_scales(self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path)
else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
else:
logger.warning("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend)
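# Illustrative usage sketch (hypothetical, not part of the original file): the
# verl ModelRunner takes the actor model (or its parameter dict) directly, and
# load_model() routes it through get_model() instead of reading weights from
# disk. Assuming the vLLM configs were built elsewhere:
#
#     runner = ModelRunner(
#         actor_module,            # or actor_module.state_dict()
#         model_config, parallel_config, scheduler_config,
#         device_config, cache_config, load_config,
#         lora_config=None,
#     )
#     runner.load_model()          # wraps get_model(actor_module, ...)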
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import os
from typing import Optional
import torch
import torch.distributed
import vllm.distributed.parallel_state as ps
from vllm.distributed.parallel_state import (
get_pp_group,
get_world_group,
init_distributed_environment,
init_model_parallel_group,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
"""
This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
- We assume the Megatron tp+dp+pp world is already established before calling this function.
"""
# Device mesh for using DTensor
_DEVICE_MESH = None
# Tensor model parallel group that the current rank belongs to.
_TP = None
# Pipeline model parallel group that the current rank belongs to.
_PP = None
# This function initializes the parallel groups when using the HybridEngine
def initialize_parallel_state(
distributed_init_method: str = "env://",
backend: str = "nccl",
tensor_model_parallel_size: int = 1,
num_tp_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    # NOTE(sgm): Modified for verl; env vars are set by TORCHRUN.
rank = int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend)
if torch.distributed.get_world_size() > 1:
        # NOTE: build a separate inference group with infer tp & micro dp
initialize_model_parallel_for_vllm(
tensor_model_parallel_size=tensor_model_parallel_size,
num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp,
)
else:
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
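# Illustrative usage sketch (hypothetical): under torchrun, RANK / LOCAL_RANK /
# WORLD_SIZE are already exported, so a hybrid-engine worker only passes the
# inference tp size and the train-tp-to-infer-tp ratio, e.g. for train tp=4
# and inference tp=2 on 8 GPUs:
#
#     initialize_parallel_state(
#         tensor_model_parallel_size=2,
#         num_tp_per_train_tp=2,
#     )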
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
# get the backend of _DEVICE_WORLD_GROUP
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
return
assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
"tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
pp_world_size = get_pp_group().world_size
assert pp_world_size == pipeline_model_parallel_size, (
"pipeline parallel group already initialized, but of unexpected size: "
f"{pp_world_size=} vs. "
f"{pipeline_model_parallel_size=}")
# TODO(sgm): deviates from v0.5.4; pipeline parallelism is not supported yet
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return ps._TP is not None
# and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def initialize_model_parallel_for_vllm(
tensor_model_parallel_size: int,
num_tensor_model_parallel_groups_per_train_tp: int = 1,
pipeline_model_parallel_size: int = 1,
) -> None:
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
assert isinstance(tensor_model_parallel_size, int)
# assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group
# assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group
# Build the tensor model-parallel groups.
assert ps._TP is None, "tensor model parallel group is already initialized"
global _TP
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = torch.distributed.get_backend()
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
if num_tensor_model_parallel_groups_per_train_tp == 1:
# if tensor_model_parallel_size == train_tensor_parallel_size:
# using the same tp group as Megatron/vllm
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
            use_custom_allreduce=False,  # TODO: check why True does not work in the Ray trainer
use_message_queue_broadcaster=True,
)
ps._TP = _TP
        # _MICRO_DATA_PARALLEL_GROUP is moved to the hybrid engine
else:
        # initialize a micro_dp group and a tp group
        # assume training tp=4 and inference tp=2; the weights are then partitioned as
        # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference
# Build the inference tp groups
# train_tp = train_tensor_parallel_size
train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size
# num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp):
start = train_tp * i
end = train_tp * (i + 1)
for j in range(num_tensor_model_parallel_groups_per_train_tp):
ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp))
                for k in range(len(ranks)):  # avoid shadowing the outer loop index `i`
                    ranks[k] += j
group_ranks.append(ranks)
_TP = init_model_parallel_group(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
backend=backend,
            use_custom_allreduce=False,  # TODO: check why True does not work in the Ray trainer
use_message_queue_broadcaster=True,
)
ps._TP = _TP
# Build the pipeline model-parallel groups.
# global _PIPELINE_MODEL_PARALLEL_GROUP
# global _PIPELINE_GLOBAL_RANKS
# assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized")
# ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group()
# ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks()
    # TODO: init using device mesh (hybrid engine not supported yet)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
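# Worked example (illustrative): world_size=8, tensor_model_parallel_size=2,
# num_tensor_model_parallel_groups_per_train_tp=2 (so train_tp=4). Then
# num_tensor_model_parallel_groups=4 and the inference tp groups come out as
# strided slices of each training tp group:
#     i=0 -> [0, 2] and [1, 3]
#     i=1 -> [4, 6] and [5, 7]
# With pipeline_model_parallel_size=1, the pp groups are the singletons
# [0] ... [7].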
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
    NOTE: This method is a hack on the open-source version: it drops the
    assertion that world_size == tp * pp.
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group)
# NOTE(sgm) we don't assert world_size == tp * pp
# DP is not managed by vllm but by the VeRL WorkerGroup
# if (world_size !=
# tensor_model_parallel_size * pipeline_model_parallel_size):
# raise RuntimeError(
# f"world_size ({world_size}) is not equal to "
# f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
# f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
rank = torch.distributed.get_rank()
global _TP
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
        use_custom_allreduce=False,  # TODO: check why True does not work in the Ray trainer
use_message_queue_broadcaster=True,
)
ps._TP = _TP
    # TODO: init using device mesh (hybrid engine not supported yet)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False)
ps._PP = _PP # for verl
"""
Device mesh utilities
"""
def get_device_mesh():
assert _DEVICE_MESH is not None, "device mesh is not initialized"
return _DEVICE_MESH
"""
Tensor model parallel utilities
"""
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TP is not None, "tensor model parallel group is not initialized"
return _TP.device_group
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
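# Worked example (illustrative): with a tensor parallel world size of 4,
# global rank 6 belongs to the tp group [4, 5, 6, 7], so
# get_tensor_model_parallel_src_rank() returns (6 // 4) * 4 = 4.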
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
import os
import socket
from typing import Dict, List, Optional, Set, Tuple
import torch
from vllm.config import (
CacheConfig,
DeviceConfig,
LoRAConfig,
ObservabilityConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
SpeculativeConfig,
)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from .config import LoadConfig, ModelConfig
logger = init_logger(__name__)
class SPMDGPUExecutor(ExecutorBase):
"""SPMD-based multi-GPU executor implementations."""
def __init__(
self,
model, # pytorch model itself or its parameter dict
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
distributed_init_method = initialize_cluster(parallel_config)
self._init_executor(model, distributed_init_method)
    # TODO(sgm): verl does not support speculative decoding yet
def _init_executor(self, model, distributed_init_method) -> None:
assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend."
# Create the parallel worker for each GPU.
self._init_workers_sp(model, distributed_init_method)
def _init_workers_sp(self, model, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from .worker import Worker # pylint: disable=import-outside-toplevel
rank = int(os.getenv("RANK"))
local_rank = int(os.getenv("LOCAL_RANK"))
print(f"local rank {local_rank}")
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ["NCCL_CUMEM_ENABLE"] = "0"
self.worker = Worker(
model,
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank,
rank,
distributed_init_method,
lora_config=self.lora_config,
speculative_config=None,
            prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=True,
model_runner_cls=None, # use the default one
)
# NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model()
self.worker.init_device()
self.worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self.worker.determine_num_available_blocks()
        # NOTE(shengguangming): we don't use a shared centralized controller here;
        # each process has its own scheduler
num_gpu_blocks = num_blocks[0]
num_cpu_blocks = num_blocks[1]
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache in all workers."""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
if torch.distributed.get_rank() == 0:
print(
f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB"
)
self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks)
if torch.distributed.get_rank() == 0:
print(
f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB"
)
    # NOTE(sgm): This will not profile & capture the model (CUDAGraph) when rebuilding the KV cache
def init_cache_engine(self) -> None:
self.worker._init_cache_engine()
def free_cache_engine(self) -> None:
self.worker.free_cache_engine()
def execute_model(self, execute_model_req) -> List[SamplerOutput]:
all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
# NOTE(sgm):
        # Each GPU in vllm under verl has its own spmd_gpu_executor, so every GPU returns its outputs;
        # in vllm with Ray, only the driver worker returns the sampling results.
return all_outputs
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.worker.add_lora(lora_request=lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.remove_lora(lora_id=lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def check_health(self) -> None:
# SPMDExecutor will always be healthy as long as
# it's running.
return
# NOTE(sgm) add for verl to pass the abstract class test, not used
from vllm.prompt_adapter.request import PromptAdapterRequest
def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
return self.worker.add_prompt_adapter(prompt_adapter_request)
def list_prompt_adapters(self) -> Set[int]:
return self.worker.list_prompt_adapters()
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.worker.pin_lora(lora_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
return self.worker.pin_prompt_adapter(prompt_adapter_id)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0."
return self.worker.remove_prompt_adapter(prompt_adapter_id)
# NOTE(sgm): add for verl
def offload_model_weights(self) -> None:
self.worker.offload_model_weights()
def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None:
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> str:
    """Initialize the distributed cluster.
    Args:
        parallel_config: The configurations for parallel execution.
    Returns:
        The `distributed_init_method` address for initializing the
        distributed backend.
    """
# Initialize cluster locally.
    port = get_open_port()  # unused with env:// init; kept for the tcp:// fallback noted below
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
# distributed_init_method = f"tcp://localhost:{port}"
distributed_init_method = "env://"
return distributed_init_method
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
# TODO(sgm): not implemented async executor yet
class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase):
async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
self.check_health()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
from typing import Optional
from transformers import PreTrainedTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import LRUCache
class TokenizerGroup(TokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int]):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = tokenizer
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None
    # FIXME(sgm): for simplicity, we expose the special tokens here
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id
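# Illustrative usage sketch (hypothetical): verl builds the group from an
# already-loaded HF tokenizer rather than a model name, and the special tokens
# are read straight through:
#
#     group = TokenizerGroup(hf_tokenizer, enable_lora=False,
#                            max_num_seqs=256, max_input_length=None)
#     pad_id, eos_id = group.pad_token_id, group.eos_token_id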
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py
"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (
CacheConfig,
DeviceConfig,
LoRAConfig,
ParallelConfig,
PromptAdapterConfig,
SchedulerConfig,
SpeculativeConfig,
)
# TODO(sgm): check why vllm has a similar file in vllm.model_executor.parallel_utils.parallel_state
from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase
from vllm.worker.model_runner_base import ModelRunnerInputBase
from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype
from vllm.worker.worker_base import WorkerInput
from .config import LoadConfig, LoadFormat, ModelConfig
from .dtensor_weight_loaders import load_dtensor_weights
from .hf_weight_loader import load_hf_weights
from .megatron_weight_loaders import load_megatron_weights
from .model_runner import ModelRunner
from .parallel_state import ensure_model_parallel_initialized
class Worker(Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
model: Union[nn.Module, Dict], # model itself or its parameter dict
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
# self.model = model # will be replaced in the init_model
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker # TODO: we don't need driver
# if parallel_config and is_driver_worker:
# assert rank % parallel_config.tensor_parallel_size == 0, \
# "Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_args = (
{} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or
(speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else {
"return_hidden_states": True
})
# TODO(sgm): set correct model runner class
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model, # [VERL]: add for verl
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine] = None
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
# NOTE(sgm): [VERL] For offloading inference engine params
self.cpu_model = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
            # NOTE(sgm): Modified for verl; env vars are set by TORCHRUN.
self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Use the world_size set by TORCHRUN
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
self.parallel_config.world_size = world_size
_check_if_gpu_supports_dtype(self.model_config.dtype)
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
# self.model = get_model(actor_model=self.model, model_config=self.model_config)
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory
assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
# NOTE(sgm) [VERL] use the remaining memory
num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size)
# num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
        # NOTE(sgm): Added for [VERL]; synchronize the number of blocks across all ranks
num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda")
num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda")
torch.distributed.all_reduce(num_gpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
torch.distributed.all_reduce(num_cpu_blocks,
op=torch.distributed.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group)
num_gpu_blocks = num_gpu_blocks.item()
num_cpu_blocks = num_cpu_blocks.item()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks
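    # Worked example (illustrative, not from the original file) for the verl
    # block formula above: with 40 GiB free after profiling,
    # gpu_memory_utilization=0.6 and a 2 MiB cache block,
    #     num_gpu_blocks = int(40 * 1024 * 0.6 // 2) = 12288
    # and the all-reduce with ReduceOp.MIN then aligns the count across the tp
    # group so every rank allocates an identically sized KV cache.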
def _init_cache_engine(self):
if self.cache_engine is None and self.gpu_cache is None:
super()._init_cache_engine()
def free_cache_engine(self):
# ensure `enforce_eager=True`
self.cache_engine = None
self.gpu_cache = None
    # NOTE(sgm): [VERL] adapted from _execute_model_spmd()
def execute_model(self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list)
# verl.worker.workerbase.WorkerBase
# swap cache
super().execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(
model_input,
self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None,
intermediate_tensors,
)
# assume the input is .state_dict()
def sync_model_weights(self, actor_weights: Dict, load_format: str):
if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]:
load_megatron_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.HF:
            # full model state dict without sharding
load_hf_weights(actor_weights, self.model_runner.model)
elif load_format == LoadFormat.DTENSOR:
load_dtensor_weights(actor_weights, self.model_runner.model)
def offload_model_weights(self) -> None:
        if self.cpu_model is None:
self.cpu_model = {}
for name, params in self.model_runner.model.named_parameters():
self.cpu_model[name] = torch.empty_like(params, device="cpu")
params.data = self.cpu_model[name]
else:
for name, params in self.model_runner.model.named_parameters():
params.data = self.cpu_model[name]
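# Illustrative usage sketch (hypothetical): a hybrid-engine step typically
# alternates the two methods above. offload_model_weights() frees GPU memory by
# pointing the params at uninitialized CPU placeholders, and
# sync_model_weights() restores real weights before the next generation phase:
#
#     worker.offload_model_weights()
#     ...the trainer updates the actor...
#     worker.sync_model_weights(actor_module.state_dict(), load_format=LoadFormat.DTENSOR)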
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = "env://",
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
    # NOTE(sgm): using tcp://localhost:xxxx will hang in the HF setting without Megatron
init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank)
ensure_model_parallel_initialized(
tensor_model_parallel_size=parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=parallel_config.pipeline_parallel_size,
)
    # TODO(sgm): check whether this is needed
# if pynccl_utils.is_initialized():
# pynccl_world_size = pynccl_utils.get_world_size()
# if pynccl_world_size != parallel_config.world_size:
# raise RuntimeError(
# "pynccl is already initialized but the pynccl world "
# "size does not match parallel_config.world_size "
# f"({pynccl_world_size} vs. {parallel_config.world_size}).")
# elif parallel_config.world_size > 1:
# # NOTE(woosuk): We don't initialize pynccl process group when world size
# # is 1.
# # NOTE(kaichao): By default, pynccl is initialized for tp group.
# pynccl_utils.init_process_group(
# group=get_tensor_model_parallel_cpu_group())
# # Initialize a custom fast all-reduce implementation.
# if not parallel_config.disable_custom_all_reduce:
# init_custom_ar()
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
# if pynccl_utils.is_initialized():
# pynccl_utils.all_reduce(torch.zeros(1).cuda())
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.