Unverified commit 6cdf015c, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues...


[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues in the future (#31747)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 5d3b6097
......@@ -117,9 +117,9 @@ class DeviceCommunicatorBase:
use_ep = False
all2all_backend = None
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config()
config = get_current_vllm_config_or_none()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
......
......@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -184,7 +184,7 @@ class QuickAllReduce:
)
return
self.qr_quant_level = QuickReduceRegime[regime_str]
vllm_config = get_current_vllm_config()
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and hasattr(vllm_config, "model_config")
......
......@@ -1177,9 +1177,9 @@ def init_distributed_environment(
distributed_init_method,
backend,
)
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config()
config = get_current_vllm_config_or_none()
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
......@@ -1251,7 +1251,7 @@ def init_distributed_environment(
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
if config.parallel_config.nnodes > 1:
if config is not None and config.parallel_config.nnodes > 1:
_NODE_COUNT = config.parallel_config.nnodes
else:
_NODE_COUNT = _node_count(_WORLD.cpu_group)
......@@ -1260,7 +1260,7 @@ def init_distributed_environment(
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size"
)
if config.parallel_config.nnodes_within_dp > 1:
if config is not None and config.parallel_config.nnodes_within_dp > 1:
if parallel_config.data_parallel_size > 1:
world_size_inner_dp = parallel_config.world_size
group_ranks = [
......@@ -1316,9 +1316,9 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config()
config = get_current_vllm_config_or_none()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
......
......@@ -13,10 +13,28 @@ from vllm.model_executor.layers.quantization.utils.layer_utils import replace_pa
from vllm.utils.torch_utils import direct_register_custom_op
_CPU_MOE_LAYER_CACHE = {}
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
class _LazyActivationDict(dict):
    """Dictionary that builds activation modules only when first requested.

    Constructing the activations eagerly would run CustomOp.__init__() at
    module import time, and that in turn calls get_current_vllm_config()
    before any config has been set. Deferring construction to the first
    lookup avoids that.
    """

    # Maps each supported activation name to the class that implements it.
    _factories: dict[str, type[SiluAndMul] | type[SwigluOAIAndMul]] = {
        "silu": SiluAndMul,
        "swigluoai": SwigluOAIAndMul,
    }

    def __missing__(self, key: str) -> SiluAndMul | SwigluOAIAndMul:
        """Create, cache and return the activation registered under ``key``.

        Called by ``dict`` when ``self[key]`` is looked up and not yet stored.

        Raises:
            KeyError: if ``key`` has no entry in ``_factories``.
        """
        factory = self._factories.get(key)
        if factory is None:
            raise KeyError(f"{key} is not a supported activation")
        # Store the new instance so later lookups reuse it.
        activation = factory()
        self[key] = activation
        return activation
_CPU_MOE_ACT = _LazyActivationDict()
def grouped_topk(
......@@ -212,7 +230,7 @@ class CPUFusedMOE:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
assert activation in _CPU_MOE_ACT._factories, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
......
......@@ -540,6 +540,20 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self._grouped_topk_impl: GroupedTopk | None = None
if self.use_grouped_topk:
assert self.num_expert_group is not None
assert self.topk_group is not None
self._grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
......@@ -1588,19 +1602,8 @@ class FusedMoE(CustomOp):
# DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None
assert self.num_expert_group is not None
grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
topk_weights, topk_ids = grouped_topk_impl(
assert self._grouped_topk_impl is not None
topk_weights, topk_ids = self._grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias,
......
......@@ -339,15 +339,11 @@ def apply_rotary_pos_emb_flashatt(
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
......@@ -410,6 +406,11 @@ class KeyeSiglipAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward(
self,
hidden_states: torch.Tensor,
......@@ -448,7 +449,7 @@ class KeyeSiglipAttention(nn.Module):
self.num_kv_heads,
self.head_dim,
)
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin, self.apply_rotary_emb)
v = v.view(
*v.shape[:-1],
self.num_kv_heads,
......
......@@ -152,16 +152,12 @@ def apply_rotary_pos_emb(
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_flash_attn_backend: bool = False,
is_flash_attn_backend: bool,
apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
elif is_flash_attn_backend and current_platform.is_rocm():
......@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward(
self,
hidden_states: torch.Tensor,
......@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module):
cos,
sin,
self.attn.is_flash_attn_backend,
self.apply_rotary_emb,
)
queries = queries.squeeze(0)
keys = keys.squeeze(0)
......
......@@ -14,7 +14,7 @@ import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.config.compilation import CompilationMode
from vllm.distributed import (
ensure_model_parallel_initialized,
......@@ -268,7 +268,9 @@ class Worker(WorkerBase):
# to hijack tensor allocation.
def load_model(self) -> None:
    """Load the model through the model runner.

    The load runs with two context managers active at the same time: the
    one returned by ``_maybe_get_memory_pool_context(tag="weights")`` and
    ``set_current_vllm_config(self.vllm_config)``.

    ``eep_scale_up`` is passed to the model runner as True when the
    environment variable ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` equals "1".
    """
    eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
    memory_pool_context = self._maybe_get_memory_pool_context(tag="weights")
    # The two managers must be separated by a comma. `with a and b:` first
    # evaluates the boolean expression (`b` when `a` is truthy, else `a`)
    # and enters only that single object, so one context would be skipped.
    with memory_pool_context, set_current_vllm_config(self.vllm_config):
        self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment