Commit aecdff18 authored by gaoqiong's avatar gaoqiong
Browse files

合入中科嘉禾mp代码

parent 6b58062d
...@@ -892,8 +892,8 @@ class ModelConfig: ...@@ -892,8 +892,8 @@ class ModelConfig:
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "slimquant_w4a8",
"slimquant_w4a8","slimquant_w4a8_marlin" "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin"
] ]
if self.quantization is not None: if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods, self.quantization = cast(me_quant.QuantizationMethods,
...@@ -920,7 +920,8 @@ class ModelConfig: ...@@ -920,7 +920,8 @@ class ModelConfig:
"awq_marlin", "awq_marlin",
"ipex", "ipex",
"moe_wna16", "moe_wna16",
"slimquant_w4a8_marlin" "slimquant_w4a8_marlin",
"slimquant_compressed_tensors_marlin"
] ]
quantization_methods = [ quantization_methods = [
q for q in supported_quantization if q not in overrides q for q in supported_quantization if q not in overrides
...@@ -1777,7 +1778,7 @@ class LoadConfig: ...@@ -1777,7 +1778,7 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"] self.ignore_patterns = ["original/**/*"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher", "mp_rpc"]
@config @config
...@@ -2003,10 +2004,10 @@ class ParallelConfig: ...@@ -2003,10 +2004,10 @@ class ParallelConfig:
logger.info("Disabling V1 multiprocessing for external launcher.") logger.info("Disabling V1 multiprocessing for external launcher.")
if self.enable_eplb: if self.enable_eplb:
# if not current_platform.is_cuda(): if not current_platform.is_cuda():
# raise ValueError( raise ValueError(
# "Expert parallelism load balancing is only supported on " "Expert parallelism load balancing is only supported on "
# "CUDA devices now.") "CUDA devices now.")
if self.num_redundant_experts < 0: if self.num_redundant_experts < 0:
raise ValueError( raise ValueError(
"num_redundant_experts must be non-negative, but got " "num_redundant_experts must be non-negative, but got "
...@@ -2068,14 +2069,14 @@ class ParallelConfig: ...@@ -2068,14 +2069,14 @@ class ParallelConfig:
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform from vllm.platforms import current_platform
if self.distributed_executor_backend not in ( if self.distributed_executor_backend not in (
"ray", "mp", "uni", "ray", "mp", "uni", "mp_rpc",
"external_launcher", None) and not (isinstance( "external_launcher", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass( self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)): self.distributed_executor_backend, ExecutorBase)):
raise ValueError( raise ValueError(
"Unrecognized distributed executor backend " "Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported " f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher' or" "values are 'ray', 'mp' 'uni', 'external_launcher', 'mp_rpc' or"
" custom ExecutorBase subclass.") " custom ExecutorBase subclass.")
if self.use_ray: if self.use_ray:
from vllm.executor import ray_utils from vllm.executor import ray_utils
...@@ -4755,12 +4756,12 @@ class VllmConfig: ...@@ -4755,12 +4756,12 @@ class VllmConfig:
batch_size_capture_list = [] batch_size_capture_list = []
if self.model_config is not None and \ if self.model_config is not None and \
not self.model_config.enforce_eager: not self.model_config.enforce_eager:
if self.model_config.use_mla and self.scheduler_config.max_num_seqs<=512: if self.model_config.use_mla and self.compilation_config.full_cuda_graph and self.scheduler_config.max_num_seqs<=512:
cuda_graph_sizes = [self.scheduler_config.max_num_seqs] cuda_graph_sizes = [self.scheduler_config.max_num_seqs]
else: else:
cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes
if len(cuda_graph_sizes) == 1: if len(cuda_graph_sizes) == 1:
batch_size_capture_list = [1, 2, 4] + [ batch_size_capture_list = [1, 2, 3, 4] + [
i for i in range(8, cuda_graph_sizes[0] + 1, 8) i for i in range(8, cuda_graph_sizes[0] + 1, 8)
] ]
elif len(cuda_graph_sizes) > 1: elif len(cuda_graph_sizes) > 1:
......
...@@ -21,7 +21,7 @@ import vllm.envs as envs ...@@ -21,7 +21,7 @@ import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.distributed.utils import StatelessProcessGroup, sched_yield
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
is_valid_ipv6_address) is_valid_ipv6_address, get_loopback_ip)
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
...@@ -255,7 +255,7 @@ class MessageQueue: ...@@ -255,7 +255,7 @@ class MessageQueue:
# for remote readers, we will: # for remote readers, we will:
# create a publish-subscribe socket to communicate large data # create a publish-subscribe socket to communicate large data
if not connect_ip: if not connect_ip:
connect_ip = get_ip() connect_ip = get_loopback_ip()
self.remote_socket = context.socket(XPUB) self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True) self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port() remote_subscribe_port = get_open_port()
......
...@@ -948,12 +948,6 @@ def init_distributed_environment( ...@@ -948,12 +948,6 @@ def init_distributed_environment(
"Fallback Gloo backend is not available.") "Fallback Gloo backend is not available.")
backend = "gloo" backend = "gloo"
# this backend is used for WORLD # this backend is used for WORLD
parallel_config = config.parallel_config
data_parallel_size = parallel_config.data_parallel_size
use_mori_ep = envs.VLLM_ALL2ALL_BACKEND == 'mori' and data_parallel_size > 1 and parallel_config.enable_expert_parallel
if use_mori_ep:
backend="cpu:gloo,cuda:nccl"
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,
init_method=distributed_init_method, init_method=distributed_init_method,
...@@ -1044,7 +1038,7 @@ def initialize_model_parallel( ...@@ -1044,7 +1038,7 @@ def initialize_model_parallel(
_TP = init_model_parallel_group(group_ranks, _TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
use_message_queue_broadcaster=True, use_message_queue_broadcaster=False,
group_name="tp") group_name="tp")
# Build the pipeline model-parallel groups. # Build the pipeline model-parallel groups.
......
...@@ -1499,7 +1499,7 @@ class EngineArgs: ...@@ -1499,7 +1499,7 @@ class EngineArgs:
if (self.pipeline_parallel_size > 1 if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend and self.distributed_executor_backend
not in (ParallelConfig.distributed_executor_backend, "ray", not in (ParallelConfig.distributed_executor_backend, "ray",
"mp", "external_launcher")): "mp", "external_launcher", "mp_rpc")):
name = "Pipeline Parallelism without Ray distributed executor " \ name = "Pipeline Parallelism without Ray distributed executor " \
"or multiprocessing executor or external launcher" "or multiprocessing executor or external launcher"
_raise_or_fallback(feature_name=name, recommend_to_remove=False) _raise_or_fallback(feature_name=name, recommend_to_remove=False)
......
...@@ -175,7 +175,10 @@ if TYPE_CHECKING: ...@@ -175,7 +175,10 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False VLLM_SCHED_ENABLE_MINIMAL_INJECTION: bool = False
VLLM_USE_PD_SPLIT: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_MP_RPC_READY_BASE_PORT: int = 28888
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -945,7 +948,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -945,7 +948,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels # - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels # - "deepep_low_latency", use deepep low-latency kernels
# - "mori", use mori kernels
"VLLM_ALL2ALL_BACKEND": "VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
...@@ -1093,7 +1095,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1093,7 +1095,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use apex for rmsnorm # vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
...@@ -1134,29 +1135,31 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1134,29 +1135,31 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm will use lightop's moe_sum fusion operator for deepseek
"VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD":
lambda: (os.getenv('VLLM_USE_DEEPSEEK_MOE_SUM_MUL_ADD', 'True').lower() in
("true", "1")),
# vllm will use silu_mul_quant fused op # vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_SILU_MUL_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))), lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens # pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS": "VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")), lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# vllm will enable minimal injection for pipeline parallel scheduling
# pd separation p2p async buf tokens "VLLM_SCHED_ENABLE_MINIMAL_INJECTION":
"VLLM_ENABLE_MOE_GROUP_GEMM": lambda: (os.getenv("VLLM_SCHED_ENABLE_MINIMAL_INJECTION", "0").lower() in
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP":
lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
# Used to get READY_BASE_PORT in multiproc_rpc_executor
"VLLM_MP_RPC_READY_BASE_PORT":
lambda: int(os.getenv("VLLM_MP_RPC_READY_BASE_PORT", "28888")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsKVCacheMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import (
CompressedTensorsMarlinMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
import os
from vllm import _custom_ops as ops
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)

# Public names of this module. ``CompressedTensorsLinearMethod`` is a re-export
# that is kept so existing star-imports keep working; the config class defined
# in this module is listed as well so it is exported too.
__all__ = [
    "CompressedTensorsLinearMethod",
    "SlimQuantCompressedTensorsMarlinConfig",
]

# Literal-typed constant holding the string "sparsity_config".
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
# Type alias: string key -> optional mapping of string key to QuantizationArgs.
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
    """Compressed-tensors config registered as
    ``slimquant_compressed_tensors_marlin``.

    Construction is delegated unchanged to ``CompressedTensorsConfig``; this
    subclass only customises the method name, the override rule and the
    per-layer quantize-method selection.
    """

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
    ):
        """Pass every argument positionally, in order, to the parent class."""
        parent_args = (
            target_scheme_map,
            ignore,
            quant_format,
            sparsity_scheme_map,
            sparsity_ignore_list,
            kv_cache_scheme,
            config,
        )
        super().__init__(*parent_args)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Return this method's name when it should replace the checkpoint's.

        The override applies only when the checkpoint config reports
        ``quant_method == "compressed-tensors"`` and the user asked for
        ``"slimquant_marlin"``; otherwise ``None`` is returned.
        """
        checkpoint_method = hf_quant_cfg.get("quant_method")
        takes_over = (checkpoint_method == "compressed-tensors"
                      and user_quant == "slimquant_marlin")
        return cls.get_name() if takes_over else None

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "slimquant_compressed_tensors_marlin"

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer`` (named ``prefix``).

        Returns ``None`` for layer types that are not handled here.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        # Layers on the ignore list are not quantized.
        # NOTE(review): the original code carried a commented-out
        # ``UnquantizedLinearMethod()`` next to this return; confirm that the
        # embedding variant is the intended choice.
        is_ignored = should_ignore_layer(
            prefix,
            ignore=self.ignore,
            fused_mapping=self.packed_modules_mapping)
        if is_ignored:
            return UnquantizedEmbeddingMethod()

        if isinstance(layer, LinearBase):
            linear_scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if linear_scheme is not None:
                layer.scheme = linear_scheme
                return CompressedTensorsLinearMethod(self)
            # No scheme matched this linear layer: leave it unquantized
            # (same NOTE(review) as above applies).
            return UnquantizedEmbeddingMethod()

        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)

        if isinstance(layer, FusedMoE):
            return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)

        return None
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from enum import Enum
from typing import Callable, Optional
from math import prod
import torch
from compressed_tensors.quantization import (QuantizationStrategy)
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoEConfig, FusedMoeWeightScaleSupported,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
# The logger is created before the optional import so a failure can be
# reported through it instead of a bare print().
logger = init_logger(__name__)

# lightop / lmslim supply the fused int8 MoE kernels used in this module. The
# import is best-effort: when it fails the module still loads, but the kernel
# names stay undefined and the MoE code paths that call them cannot run.
try:
    from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep
    from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
    logger.info(
        "Please install lmslim if you want to infer the quantitative model "
        "of moe.")

__all__ = [
    "CompressedTensorsMarlinMoEMethod",
    "CompressedTensorsW8A8Int8MarlinMoEMethod",
]
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
    """Base class and factory for the Marlin-backed fused-MoE methods."""

    @staticmethod
    def get_moe_method(
        quant_config: "SlimQuantCompressedTensorsMarlinConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
    ) -> "CompressedTensorsMarlinMoEMethod":
        """Build the MoE method matching the config's ``"Linear"`` scheme.

        Only the dynamic per-token W8A8 scheme is accepted.

        Raises:
            RuntimeError: for any other weight/activation scheme.
        """
        linear_scheme = quant_config.target_scheme_map["Linear"]
        weight_quant = linear_scheme.get("weights")
        input_quant = linear_scheme.get("input_activations")

        # Guard clause: reject everything that is not dynamic-token W8A8.
        if not quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
            raise RuntimeError(
                "Slimquant_marlin does not support the FusedMoe scheme: "
                f"{weight_quant}, {input_quant}")

        return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
    """Fused-MoE method for channel-wise int8 weights with dynamic per-token
    activation quantization.

    ``apply`` runs ``fused_moe_forward``. ``select_gemm_impl`` hands
    ``w8a8_groupgemm_forward`` to ``TritonOrGroupGemmExperts`` when the
    prepare/finalize object uses the batched-experts activation format.
    """

    def __init__(
        self,
        quant_config: "CompressedTensorsMarlinConfig"  # type: ignore # noqa E501
    ):
        """Validate the ``"Linear"`` scheme and record the parallel setup.

        Raises:
            ValueError: if the weights are not channel-wise quantized, the
                activations are not per-token quantized, or the activation
                scales are static.
        """
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")

        self.fused_experts = self.fused_moe_forward

        # Set by select_gemm_impl() for the batched-experts format. It must
        # exist from construction on, because groupgemm_workspace_shapes()
        # tests it against None and would otherwise raise AttributeError.
        self.max_num_tokens_per_rank = None

        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        dp_size = get_dp_group().world_size
        self.use_deepep = (
            dp_size > 1 and parallel_config.enable_expert_parallel
            and envs.VLLM_ALL2ALL_BACKEND in ("deepep_high_throughput",
                                              "deepep_low_latency"))
        if self.use_deepep:
            all2all_manager = get_ep_group().device_communicator.all2all_manager
            assert all2all_manager is not None
            self.num_dispatchers = all2all_manager.world_size

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the int8 expert weights and float32 per-channel scales.

        ``params_dtype`` is overridden with ``torch.int8`` for the weights.
        Input scales are set to ``None`` because activations are quantized
        dynamically.
        """
        if self.use_deepep:
            # GEMM dimensions consumed later by w8a8_groupgemm_forward().
            self.N = 2 * intermediate_size_per_partition
            self.K = hidden_size

        params_dtype = torch.int8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        hidden_size,
                                                        1,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack every expert's weights and replace the layer parameters.

        The repacking helper depends only on ``self.use_deepep``, so it is
        chosen once instead of once per expert.
        """
        repack = (w8a8_nt_kpack2_marlin_weight
                  if self.use_deepep else get_w8a8_int8_marlin_weights)

        w13_packed = torch.stack(
            [repack(layer.w13_weight[expert])
             for expert in range(layer.w13_weight.shape[0])],
            dim=0)
        w2_packed = torch.stack(
            [repack(layer.w2_weight[expert])
             for expert in range(layer.w2_weight.shape[0])],
            dim=0)

        layer.w13_weight = Parameter(w13_packed, requires_grad=False)
        layer.w2_weight = Parameter(w2_packed, requires_grad=False)

    def groupgemm_workspace_shapes(self,
                                   a: torch.Tensor,
                                   aq: torch.Tensor,
                                   M: int,
                                   N: int,
                                   K: int,
                                   topk: int,
                                   global_num_experts: int,
                                   local_num_experts: int,):
        """Return ``(workspace13, workspace2, output, dtype)`` for the grouped
        GEMM path.

        The first three entries are shape tuples; ``dtype`` is ``a.dtype``.
        ``aq``, ``M``, ``topk`` and ``global_num_experts`` are accepted but
        not used. Reads ``self.num_dispatchers``, which ``__init__`` sets only
        when ``use_deepep`` is true.
        """
        assert a.dim() == 2
        # FIXME (varun): We should be able to dispatch only from the leader
        # DP ranks in the case of TP > 1. At the moment, all the Ranks
        # end up sending their tokens. This needs to be fixed.
        num_dispatchers = self.num_dispatchers
        num_experts = local_num_experts
        # Fall back to the number of rows of ``a`` until select_gemm_impl()
        # has recorded a per-rank limit.
        max_num_tokens = a.size(
            0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
                       max(K, N))
        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
        output = (num_experts, max_num_tokens * num_dispatchers, K)
        return (workspace13, workspace2, output, a.dtype)

    def w8a8_groupgemm_forward(self,
                               x: torch.Tensor,
                               w1: torch.Tensor,
                               w2: torch.Tensor,
                               topk_ids: torch.Tensor,
                               topk_weights: torch.Tensor,
                               global_num_experts: int = -1,
                               expert_map: Optional[torch.Tensor] = None,
                               apply_router_weight_on_input: bool = False,
                               activation: str = "silu",
                               w1_scale: Optional[torch.Tensor] = None,
                               w2_scale: Optional[torch.Tensor] = None,
                               a1_scale: Optional[torch.Tensor] = None,
                               a2_scale: Optional[torch.Tensor] = None,
                               expert_num_tokens: Optional[torch.Tensor] = None,
                               use_nn_moe: Optional[bool] = False,
                               routed_scaling_factor: Optional[float] = None,
                               shared_output: Optional[torch.Tensor] = None,
                               q_x: Optional[torch.Tensor] = None,
                               **_):
        """Run the expert MLP with two masked grouped int8 GEMMs.

        Sequence: grouped GEMM with ``w1``, then ``fuse_silu_mul_quant_ep``,
        then grouped GEMM with ``w2``. Both the intermediate and the returned
        buffer are views of a single workspace allocation.

        ``expert_num_tokens`` is required even though its default is ``None``.
        Reads ``self.N`` / ``self.K``, which ``create_weights`` sets only when
        ``use_deepep`` is true.
        """
        import vllm.model_executor.layers.fused_moe.modular_kernel as mk

        # Checked before the first kernel call: the first grouped GEMM already
        # consumes expert_num_tokens.
        assert expert_num_tokens is not None

        local_num_experts = w1.size(0)
        E, max_num_tokens, _unused_n, _unused_k, top_k = mk._moe_problem_size(
            q_x, w1, w2, topk_ids)
        N, K = self.N, self.K

        # The workspace2 shape is computed by the helper but not needed here.
        (workspace13_shape, _workspace2_shape, fused_out_shape,
         workspace_dtype) = self.groupgemm_workspace_shapes(
             x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
             local_num_experts)

        workspace13 = torch.empty(prod(workspace13_shape),
                                  device=x.device,
                                  dtype=workspace_dtype)

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
        fused_out = _resize_cache(workspace13, fused_out_shape)

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
        # may lead to better performance.
        expected_m = max_num_tokens

        m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
                                      (w1, w1_scale),
                                      workspace1,
                                      expert_num_tokens,
                                      expected_m,
                                      )

        a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)

        m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
                                      (w2, w2_scale),
                                      fused_out,
                                      expert_num_tokens,
                                      expected_m)
        return fused_out

    def fused_moe_forward(self,
                          x: torch.Tensor,
                          w1: torch.Tensor,
                          w2: torch.Tensor,
                          topk_ids: torch.Tensor,
                          topk_weights: torch.Tensor,
                          global_num_experts: int = -1,
                          expert_map: Optional[torch.Tensor] = None,
                          apply_router_weight_on_input: bool = False,
                          activation: str = "silu",
                          w1_scale: Optional[torch.Tensor] = None,
                          w2_scale: Optional[torch.Tensor] = None,
                          a1_scale: Optional[torch.Tensor] = None,
                          a2_scale: Optional[torch.Tensor] = None,
                          expert_num_tokens: Optional[torch.Tensor] = None,
                          use_nn_moe: Optional[bool] = False,
                          routed_scaling_factor: Optional[float] = None,
                          shared_output: Optional[torch.Tensor] = None,
                          **_):
        """Forward to ``fused_experts_impl_int8_marlin`` with int8 W8A8 and
        per-channel quantization enabled, operating in place.

        Extra keyword arguments are accepted and ignored.
        """
        # NOTE(review): ``use_nn_moe`` is accepted but the kernel is always
        # called with ``use_nn_moe=False``; this looks deliberate, confirm.
        return fused_experts_impl_int8_marlin(
            hidden_states=x,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            use_nn_moe=False,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
        shared_output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run ``self.fused_experts`` on them.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used.

        Raises:
            NotImplementedError: if ``enable_eplb`` is true.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsW8A8Int8MarlinMoEMethod` yet.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
            e_score_correction_bias=e_score_correction_bias,
            # int64 expert ids are requested only when DeepEP is in use.
            indices_type=torch.int64 if self.use_deepep else None,)

        return self.fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,
        )

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the ``TritonOrGroupGemmExperts`` object for ``prepare_finalize``.

        For the batched-experts activation format, the per-rank token limit is
        stored on ``self`` and the grouped-GEMM forward is used. Otherwise
        ``fused_moe_forward`` is used. ``moe`` is accepted but not used.
        """
        from vllm.model_executor.layers.fused_moe import (
            TritonOrGroupGemmExperts)

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            max_num_tokens_per_rank = (
                prepare_finalize.max_num_tokens_per_rank())
            assert max_num_tokens_per_rank is not None
            # Read back later by groupgemm_workspace_shapes().
            self.max_num_tokens_per_rank = max_num_tokens_per_rank
            logger.debug(
                "TritonOrGroupGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank,
                None, True)

            return TritonOrGroupGemmExperts(
                use_int8_w8a8=True,
                per_act_token_quant=True,
                fused_experts=self.w8a8_groupgemm_forward
            )
        else:
            logger.debug(
                "TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__, None,
                False)
            return TritonOrGroupGemmExperts(
                fused_experts=self.fused_moe_forward
            )
\ No newline at end of file
...@@ -70,6 +70,7 @@ import vllm.envs as envs ...@@ -70,6 +70,7 @@ import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
import json import json
if TYPE_CHECKING: if TYPE_CHECKING:
from argparse import Namespace from argparse import Namespace
...@@ -80,12 +81,11 @@ logger = init_logger(__name__) ...@@ -80,12 +81,11 @@ logger = init_logger(__name__)
# This value is chosen to have a balance between ITL and TTFT. Note it is # This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput. # not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 DEFAULT_MAX_NUM_BATCHED_TOKENS = 10240
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
is_kme = any(arch in GPU_ARCH for arch in ["gfx928"])
SUPPORT_TC = any(arch in GPU_ARCH for arch in ["gfx928", "gfx936"]) SUPPORT_TC = any(arch in GPU_ARCH for arch in ["gfx928", "gfx936"])
def _generate_random_int8( def _generate_random_int8(
...@@ -630,6 +630,31 @@ def get_ip() -> str: ...@@ -630,6 +630,31 @@ def get_ip() -> str:
stacklevel=2) stacklevel=2)
return "0.0.0.0" return "0.0.0.0"
def test_loopback_bind(address, family):
    """Return True if a UDP socket of ``family`` can be bound to ``address``.

    Args:
        address: IP address string to try binding to.
        family: socket address family, e.g. ``socket.AF_INET``.

    Returns:
        ``True`` when creating, binding and closing the socket all succeed,
        ``False`` when any of those steps raises ``OSError``.
    """
    try:
        # The context manager closes the socket even when bind() raises;
        # the previous explicit close() was skipped on failure and leaked
        # the file descriptor.
        with socket.socket(family, socket.SOCK_DGRAM) as probe:
            probe.bind((address, 0))  # Port 0 = auto assign
    except OSError:
        return False
    return True
def get_loopback_ip() -> str:
    """Return the loopback address to use for local communication.

    A non-empty ``VLLM_LOOPBACK_IP`` setting wins. Otherwise ``127.0.0.1`` is
    probed first and ``::1`` second, and the first one that can be bound is
    returned.

    Raises:
        RuntimeError: if neither loopback address can be bound.
    """
    configured_ip = envs.VLLM_LOOPBACK_IP
    if configured_ip:
        return configured_ip

    # No explicit setting: probe IPv4 loopback, then IPv6 loopback.
    if test_loopback_bind("127.0.0.1", socket.AF_INET):
        return "127.0.0.1"
    if test_loopback_bind("::1", socket.AF_INET6):
        return "::1"

    raise RuntimeError(
        "Neither 127.0.0.1 nor ::1 are bound to a local interface. "
        "Set the VLLM_LOOPBACK_IP environment variable explicitly.")
def is_valid_ipv6_address(address: str) -> bool: def is_valid_ipv6_address(address: str) -> bool:
try: try:
......
...@@ -44,6 +44,10 @@ class Executor(ExecutorBase): ...@@ -44,6 +44,10 @@ class Executor(ExecutorBase):
elif distributed_executor_backend == "mp": elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor executor_class = MultiprocExecutor
elif distributed_executor_backend == "mp_rpc":
from vllm.v1.executor.multiproc_rpc_executor import (
MultiprocRPCExecutor)
executor_class = MultiprocRPCExecutor
elif distributed_executor_backend == "uni": elif distributed_executor_backend == "uni":
executor_class = UniProcExecutor executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher": elif distributed_executor_backend == "external_launcher":
......
...@@ -530,3 +530,4 @@ class WorkerProc: ...@@ -530,3 +530,4 @@ class WorkerProc:
if output_rank is None or self.rank == output_rank: if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue( self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output)) (WorkerProc.ResponseStatus.SUCCESS, output))
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import pickle
import dill
import signal
import sys
import threading
import time
import traceback
import socket
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.process import BaseProcess
from typing import Any, Callable, Optional, Union, cast
import cloudpickle
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip,
get_loopback_ip, get_open_port)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class MultiprocRPCExecutor(Executor):
    """Executor that talks to workers it does not spawn itself.

    Start-up is a three step TCP handshake, one listening port per rank
    (``VLLM_MP_RPC_READY_BASE_PORT + rank``):

    1. accept one connection per rank,
    2. send every worker the pickled configuration,
    3. read every worker's READY (or FAILURE) reply.

    After the handshake all traffic goes through ``MessageQueue`` objects:
    one broadcast queue from the executor to the workers and one response
    queue per worker.
    """

    def _init_executor(self) -> None:
        """Run the handshake with all workers and set up the message queues.

        Raises:
            Exception: if any step of the handshake fails. In that case a
                best-effort "shutdown" command is broadcast and every socket
                opened here is closed before the exception propagates.
        """
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)
        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None
        self.io_thread_pool: Optional[ThreadPoolExecutor] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        pp_parallel_size = self.parallel_config.pipeline_parallel_size
        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
            f"_parallel_size ({pp_parallel_size}). ")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # The init method advertises get_ip() rather than the loopback
        # address. NOTE(review): presumably so that workers running on other
        # hosts can reach it -- confirm against the deployment.
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.distributed_init_method = distributed_init_method

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
        self.rpc_broadcast_mq = MessageQueue(self.world_size,
                                             0,  # self.world_size,
                                             max_chunk_bytes=max_chunk_bytes)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Both lists are created before the try block so the cleanup path
        # below can always see (and close) whatever was opened.
        unready_workers: list[UnreadyWorkerProcHandle] = []
        connections: list[Optional[socket.socket]] = [None] * self.world_size
        success = False
        try:
            host_ip = "0.0.0.0"
            display_ip = get_loopback_ip()
            for rank in range(self.world_size):
                port = envs.VLLM_MP_RPC_READY_BASE_PORT + rank
                server_socket = socket.socket(socket.AF_INET,
                                              socket.SOCK_STREAM)
                # Register the socket first so it is closed by the cleanup
                # path even when setsockopt/bind/listen raises.
                unready_workers.append(
                    UnreadyWorkerProcHandle(
                        proc=None,  # type: ignore
                        rank=rank,
                        ready_pipe=server_socket))  # type: ignore
                server_socket.setsockopt(socket.SOL_SOCKET,
                                         socket.SO_REUSEADDR, 1)
                server_socket.bind((host_ip, port))
                server_socket.listen(1)

            logger.info("Executor waiting for %d workers to connect...",
                        self.world_size)
            for unready_handle in unready_workers:
                port = envs.VLLM_MP_RPC_READY_BASE_PORT + unready_handle.rank
                logger.info(" - Worker Rank %d should connect to %s 0.0.0.0:%d",
                            unready_handle.rank, display_ip, port)

            # Step 1: Accept connections from all workers
            for unready_handle in unready_workers:
                server_socket = unready_handle.ready_pipe
                rank = unready_handle.rank
                conn, addr = server_socket.accept()
                logger.info("Accepted connection from worker rank %d at %s",
                            rank, str(addr))
                connections[rank] = conn
                server_socket.close()  # Close listening socket

            # Step 2: Sequentially send configs to all workers
            self._send_configs_to_workers(connections, scheduler_output_handle)

            # Step 3: Sequentially wait for ready signals from all workers
            self.workers = self._wait_for_workers_ready(
                connections, unready_workers)

            # Ensure message queues are ready.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            logger.warning("Remote worker monitoring is not implemented. "
                           "System relies on RPC timeouts to detect failures.")
            success = True
        finally:
            if not success:
                if self.rpc_broadcast_mq is not None:
                    logger.info("Sending shutdown command to all workers...")
                    try:
                        self.rpc_broadcast_mq.enqueue(
                            ("shutdown", (), {}, None))
                    except Exception as e:
                        logger.warning(
                            "Could not send shutdown command to workers: %s",
                            e)
                # Close the listening sockets (closing twice is harmless).
                for handle in unready_workers:
                    if handle.ready_pipe:
                        handle.ready_pipe.close()
                # Close every accepted connection that is still open so a
                # failed handshake does not leak file descriptors.
                for conn in connections:
                    if conn is not None:
                        conn.close()

        # For pipeline parallel, we use a thread pool for asynchronous
        # execute_model.
        if self.max_concurrent_batches > 1:
            # Note: must use only 1 IO thread to keep dequeue sequence
            # from the response queue
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")

        self.output_rank = self._get_output_rank()

    def register_failure_callback(self, callback: FailureCallback):
        """Store ``callback``, or call it right away if already failed."""
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Broadcast ``execute_model`` and return the output rank's reply.

        The reply is a ``Future`` when ``max_concurrent_batches > 1``
        (non-blocking mode), otherwise the dequeued result itself.
        """
        (output, ) = self.collective_rpc(
            "execute_model",
            args=(scheduler_output, ),
            unique_reply_rank=self.output_rank,
            non_block=self.max_concurrent_batches > 1,
            timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
        return output

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       non_block: bool = False,
                       unique_reply_rank: Optional[int] = None) -> list[Any]:
        """Broadcast one call to all workers and collect the replies.

        Args:
            method: Worker attribute name, or a callable which is sent
                serialized with cloudpickle.
            timeout: Overall deadline in seconds shared by all dequeues;
                ``None`` waits without a deadline.
            args: Positional arguments forwarded to the call.
            kwargs: Keyword arguments forwarded to the call.
            non_block: If true, each reply is a ``Future`` resolved on the
                IO thread pool instead of a dequeued value.
            unique_reply_rank: If set, only that rank's reply is collected.

        Returns:
            One entry per awaited worker, in worker order.

        Raises:
            RuntimeError: if the executor is marked failed, or a worker
                answers with a non-SUCCESS status (blocking mode).
            TimeoutError: if a blocking dequeue times out.
        """
        if self.is_failed:
            raise RuntimeError("Executor failed.")

        deadline = None if timeout is None else time.monotonic() + timeout
        kwargs = kwargs or {}

        # NOTE: If the args are heterogeneous, then we pack them into a list,
        # and unpack them in the method of every worker, because every worker
        # knows their own rank.
        try:
            if isinstance(method, str):
                send_method = method
            else:
                send_method = cloudpickle.dumps(
                    method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue(
                (send_method, args, kwargs, unique_reply_rank))

            workers = (self.workers[unique_reply_rank],
                       ) if unique_reply_rank is not None else self.workers
            responses = []

            def get_response(w: WorkerProcHandle,
                             dequeue_timeout: Optional[float] = None,
                             cancel_event: Optional[threading.Event] = None):
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout, cancel=cancel_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")
                return result

            for w in workers:
                # Every dequeue only gets what is left of the shared deadline.
                dequeue_timeout = None if deadline is None else (
                    deadline - time.monotonic())

                if non_block:
                    result = self.io_thread_pool.submit(  # type: ignore
                        get_response, w, dequeue_timeout, self.shutdown_event)
                else:
                    result = get_response(w, dequeue_timeout)

                responses.append(result)

            return responses
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e

    def shutdown(self):
        """Properly shut down the executor and its workers"""
        if getattr(self, 'shutting_down', False):
            return
        self.shutting_down = True

        # This method is also the weakref finalizer, so it may run after a
        # partially completed _init_executor. Look attributes up defensively
        # instead of raising AttributeError at interpreter exit.
        shutdown_event = getattr(self, "shutdown_event", None)
        if shutdown_event is not None:
            shutdown_event.set()

        broadcast_mq = getattr(self, "rpc_broadcast_mq", None)
        if broadcast_mq is not None:
            logger.info("Sending shutdown command to all workers...")
            try:
                broadcast_mq.enqueue(("shutdown", (), {}, None))
            except Exception as e:
                logger.warning(
                    "Could not send shutdown command to workers: %s", e)

        io_thread_pool = getattr(self, "io_thread_pool", None)
        if io_thread_pool is not None:
            io_thread_pool.shutdown(wait=False, cancel_futures=True)
            self.io_thread_pool = None

        self.rpc_broadcast_mq = None

    def check_health(self) -> None:
        """Issue a ``check_health`` RPC to all workers (10 second timeout)."""
        self.collective_rpc("check_health", timeout=10)

    @property
    def max_concurrent_batches(self) -> int:
        """Number of concurrent batches; equals the pipeline parallel size."""
        return self.parallel_config.pipeline_parallel_size

    def _get_output_rank(self) -> int:
        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
        # (the first TP worker of the last PP stage).
        # Example:
        # Assuming TP=8, PP=4, then the world_size=32
        # 0-7, PP rank 0
        # 8-15, PP rank 1
        # 16-23, PP rank 2
        # 24-31, PP rank 3
        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
        return self.world_size - self.parallel_config.tensor_parallel_size

    @staticmethod
    def _recv_exact(conn: socket.socket, num_bytes: int) -> bytes:
        """Read up to ``num_bytes`` from ``conn``, looping over short reads.

        Returns fewer bytes than requested only if the peer closed the
        connection first; callers must check the length.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _send_configs_to_workers(self,
                                 connections: list[Optional[socket.socket]],
                                 scheduler_output_handle: Handle):
        """Sequentially sends configuration to all connected workers.

        The payload is a dill-serialized dict framed by a 4-byte big-endian
        length prefix. Any send error is logged and re-raised.
        """
        logger.info("Sending configuration to all workers...")

        # Gather all VLLM_ prefixed environment variables so the workers can
        # mirror them.
        vllm_envs = {
            k: v
            for k, v in os.environ.items() if k.startswith("VLLM_")
        }
        logger.info(self.vllm_config)
        logger.info(self.distributed_init_method)
        logger.info(scheduler_output_handle)

        config_payload = dill.dumps({
            "vllm_config": self.vllm_config,
            "distributed_init_method": self.distributed_init_method,
            "input_shm_handle": scheduler_output_handle,
            "vllm_envs": vllm_envs,  # Add envs to the payload
        })
        payload_header = len(config_payload).to_bytes(4, 'big')

        for rank, conn in enumerate(connections):
            assert conn is not None
            try:
                conn.sendall(payload_header)
                conn.sendall(config_payload)
                logger.info("Configuration sent to worker rank %d.", rank)
            except Exception as e:
                logger.error("Failed to send config to worker %d: %s", rank, e)
                raise

    def _wait_for_workers_ready(
        self, connections: list[Optional[socket.socket]],
        unready_proc_handles: list["UnreadyWorkerProcHandle"]
    ) -> list["WorkerProcHandle"]:
        """Sequentially waits for a READY signal from all workers.

        Each connection is closed once its reply has been processed.

        Raises:
            Exception: "WorkerProc initialization failed..." if a worker
                disconnects, sends a malformed reply, or reports a status
                other than READY. The underlying error is chained as the
                cause where there is one.
        """
        logger.info("Waiting for ready signal from all workers...")
        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
            [None] * self.world_size)

        for rank, conn in enumerate(connections):
            unready_proc_handle = unready_proc_handles[rank]
            assert conn is not None
            init_error = Exception(
                "WorkerProc initialization failed. See logs for details.")
            try:
                with conn:
                    len_data = self._recv_exact(conn, 4)
                    if len(len_data) != 4:
                        raise ConnectionAbortedError(
                            f"Worker {rank} disconnected before "
                            "sending ready signal.")
                    payload_len = int.from_bytes(len_data, 'big')
                    payload = self._recv_exact(conn, payload_len)
                    if not payload:
                        raise ConnectionAbortedError(
                            f"Worker {rank} sent an empty ready payload.")
                    if len(payload) != payload_len:
                        raise ConnectionAbortedError(
                            f"Worker {rank} sent a truncated ready payload "
                            f"({len(payload)} of {payload_len} bytes).")

                    # NOTE(security): pickle.loads runs arbitrary code chosen
                    # by the peer, and the listening socket is bound to all
                    # interfaces. Only expose these ports on a trusted
                    # network.
                    response: dict[str, Any] = pickle.loads(payload)
                    if response["status"] != WorkerProc.READY_STR:
                        logger.error("Worker %d failed to initialize: %s",
                                     rank, response.get('error'))
                        worker_traceback = response.get('traceback')
                        if worker_traceback:
                            logger.error("Worker %d traceback:\n%s", rank,
                                         worker_traceback)
                        raise init_error

                    logger.info("Received ready signal from worker rank %d.",
                                rank)
                    worker_response_mq = MessageQueue.create_from_handle(
                        response["handle"], 0)
                    ready_proc_handles[rank] = (
                        WorkerProcHandle.from_unready_handle(
                            unready_proc_handle, worker_response_mq))
            except Exception as e_inner:
                if e_inner is init_error:
                    # Already the final error; do not chain it to itself.
                    raise
                raise init_error from e_inner

        return cast(list[WorkerProcHandle], ready_proc_handles)
@dataclass
class UnreadyWorkerProcHandle:
    """WorkerProcess handle before READY."""
    # Process object of the worker. The executor in this file constructs the
    # handle with proc=None, since it does not create the worker itself.
    proc: BaseProcess
    # Rank of the worker this handle belongs to.
    rank: int
    # Listening socket on which the executor accepts this rank's connection;
    # it is closed once the connection has been accepted.
    ready_pipe: socket.socket
@dataclass
class WorkerProcHandle:
    """Handle for a worker that has completed the READY handshake."""
    proc: BaseProcess
    rank: int
    # Queue the worker process writes its responses to.
    worker_response_mq: MessageQueue

    @classmethod
    def from_unready_handle(
            cls, unready_handle: UnreadyWorkerProcHandle,
            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
        """Build a ready handle from ``unready_handle`` plus its queue.

        ``proc`` and ``rank`` are carried over unchanged.
        """
        process = unready_handle.proc
        worker_rank = unready_handle.rank
        return cls(process, worker_rank, worker_response_mq)
class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: "VllmConfig",
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle: "Handle",
    ):
        """Create the wrapped worker, its queues, and load the model.

        Args:
            vllm_config: Configuration forwarded to the wrapped worker.
            local_rank: Local rank forwarded to the wrapped worker.
            rank: Global rank; also used as the wrapper's ``rpc_rank``.
            distributed_init_method: Init method forwarded to the worker.
            input_shm_handle: Handle of the executor's broadcast queue.
        """
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: list[dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
        is_driver_worker = (
            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper

        pid = os.getpid()
        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)

        # Initialize MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank)

        # Initializes a message queue for sending the model output
        # TODO: dynamically detect the number of local readers
        self.worker_response_mq = MessageQueue(1, n_local_reader=0)

        logger.info("Initialize device and loads weights")
        # Initialize device and loads weights
        self.worker.init_device()
        self.worker.load_model()

    def shutdown(self):
        """Drop the queues and tear down the distributed state."""
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()

    @staticmethod
    def _recv_exact(conn: socket.socket, num_bytes: int) -> bytes:
        """Read up to ``num_bytes`` from ``conn``, looping over short reads.

        Returns fewer bytes than requested only if the peer closed the
        connection first; callers must check the length.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @staticmethod
    def worker_main(**kwargs):
        """ Worker initialization and execution loops.
        This runs a background process

        Required keyword arguments: ``executor_addr``, ``ready_port``,
        ``local_rank`` and ``rank``. The function connects to the executor,
        receives the configuration, builds the worker, reports READY (or
        FAILURE) over the same socket and then serves RPCs until told to
        shut down.
        """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        executor_addr = kwargs.pop("executor_addr")
        ready_port = kwargs.pop("ready_port")
        local_rank = kwargs.pop("local_rank")
        rank = kwargs.pop("rank")

        ready_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # Set connection timeout parameters
        max_retry_time = 300  # 300 seconds maximum wait time
        retry_interval = 60  # 60 seconds between retries
        start_time = time.time()

        try:
            # Add connection retry logic with logging
            connected = False
            last_log_time = 0.0

            while not connected and (time.time() -
                                     start_time) < max_retry_time:
                try:
                    ready_socket.connect((executor_addr, ready_port))
                    connected = True
                except ConnectionRefusedError:
                    # The state of a socket after a failed connect() is
                    # unspecified, so retry on a fresh one.
                    ready_socket.close()
                    ready_socket = socket.socket(socket.AF_INET,
                                                 socket.SOCK_STREAM)

                    current_time = time.time()
                    if current_time - last_log_time >= retry_interval:
                        logger.debug(
                            "Waiting for executor connection... "
                            "Executor: %s:%d, Time elapsed: %ds",
                            executor_addr, ready_port,
                            int(current_time - start_time))
                        last_log_time = current_time

                    # Check if we should continue waiting
                    if (time.time() - start_time) < max_retry_time:
                        # Wait a full retry interval before the next attempt.
                        time.sleep(retry_interval)
                    else:
                        raise ConnectionError(
                            f"Failed to connect to executor at "
                            f"{executor_addr}:{ready_port} after "
                            f"{max_retry_time} seconds") from None

            if not connected:
                raise ConnectionError(
                    f"Unable to establish connection to executor "
                    f"after {max_retry_time} seconds")

            # The socket is deliberately NOT wrapped in a `with` block: it
            # has to stay open until the except handler below had a chance
            # to report an initialization failure to the executor.

            # 1. RECEIVE config from executor
            len_data = WorkerProc._recv_exact(ready_socket, 4)
            if len(len_data) != 4:
                raise ConnectionError(
                    "Executor closed connection during config exchange.")
            payload_len = int.from_bytes(len_data, 'big')
            payload = WorkerProc._recv_exact(ready_socket, payload_len)
            if not payload or len(payload) != payload_len:
                raise ConnectionError(
                    "Did not receive config payload from executor.")

            # NOTE(security): dill.loads runs arbitrary code chosen by the
            # peer; only connect to a trusted executor.
            config_data = dill.loads(payload)

            # Set environment variables received from the executor.
            # This should be done before initializing other components.
            vllm_envs = config_data.get("vllm_envs", {})
            exclude_list = ["VLLM_LOOPBACK_IP"]
            for k, v in vllm_envs.items():
                if k in exclude_list:
                    continue
                existing_v = os.getenv(k)
                if existing_v is not None and existing_v != v:
                    logger.warning(
                        "Overwriting worker's environment variable '%s'. "
                        "Existing value: '%s', New value: '%s'", k,
                        existing_v, v)
                os.environ[k] = v

            vllm_config = config_data["vllm_config"]
            distributed_init_method = config_data["distributed_init_method"]
            input_shm_handle = config_data["input_shm_handle"]

            worker = WorkerProc(vllm_config, local_rank, rank,
                                distributed_init_method, input_shm_handle)

            # 2. SEND ready signal back to executor
            ready_payload = pickle.dumps({
                "status":
                WorkerProc.READY_STR,
                "handle":
                worker.worker_response_mq.export_handle(),
            })
            ready_socket.sendall(len(ready_payload).to_bytes(4, 'big'))
            ready_socket.sendall(ready_payload)

            # Handshake finished; later failures are reported through the
            # response queue, not through this socket.
            ready_socket.close()

            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()

            worker.worker_busy_loop()

        except Exception as e:
            logger.exception(
                "WorkerProc failed during initialization or execution.")
            # The socket is still open only while the handshake has not
            # completed, i.e. exactly when the executor is waiting for a
            # status message.
            if ready_socket.fileno() != -1:
                error_payload = pickle.dumps({
                    "status":
                    "FAILURE",
                    "error":
                    str(e),
                    "traceback":
                    traceback.format_exc()
                })
                try:
                    ready_socket.sendall(
                        len(error_payload).to_bytes(4, 'big'))
                    ready_socket.sendall(error_payload)
                except OSError as send_error:
                    # E.g. the connection was never established; reporting
                    # is best-effort and must not mask the original error.
                    logger.warning(
                        "Could not report failure to executor: %s",
                        send_error)

            # The parent sends a SIGTERM to all worker processes if
            # any worker dies. Set this value so we don't re-throw
            # SystemExit() to avoid zmq exceptions in __del__.
            shutdown_requested = True

        finally:
            if ready_socket.fileno() != -1:
                ready_socket.close()
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()

    class ResponseStatus(Enum):
        SUCCESS = auto()
        FAILURE = auto()

    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers

        Dequeues ``(method, args, kwargs, output_rank)`` tuples until the
        method is ``"shutdown"``. A reply is enqueued only when
        ``output_rank`` is ``None`` or equals this worker's rank.
        """
        while True:
            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
            if method == "shutdown":
                logger.info("Received shutdown command. Exiting busy loop.")
                break

            should_reply = output_rank is None or self.rank == output_rank
            try:
                if isinstance(method, str):
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                else:
                    # Without this branch `func` would silently keep the
                    # callable resolved in a previous iteration.
                    raise TypeError(
                        "RPC method must be str or bytes, got "
                        f"{type(method).__name__}")
                output = func(*args, **kwargs)
            except Exception as e:
                # Notes have been introduced in python 3.11
                if hasattr(e, "add_note"):
                    e.add_note(traceback.format_exc())
                logger.exception("WorkerProc hit an exception.")
                # exception might not be serializable, so we convert it to
                # string, only for logging purpose.
                if should_reply:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

            if should_reply:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment