Unverified commit 6cdf015c, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues...


[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues in the future (#31747)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 5d3b6097
...@@ -117,9 +117,9 @@ class DeviceCommunicatorBase: ...@@ -117,9 +117,9 @@ class DeviceCommunicatorBase:
use_ep = False use_ep = False
all2all_backend = None all2all_backend = None
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config() config = get_current_vllm_config_or_none()
if config is not None: if config is not None:
# as long as we use data parallel (coupled data parallel # as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together), # where all data parallel ranks execute forward together),
......
...@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup ...@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config_or_none
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -184,7 +184,7 @@ class QuickAllReduce: ...@@ -184,7 +184,7 @@ class QuickAllReduce:
) )
return return
self.qr_quant_level = QuickReduceRegime[regime_str] self.qr_quant_level = QuickReduceRegime[regime_str]
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config_or_none()
if ( if (
vllm_config is not None vllm_config is not None
and hasattr(vllm_config, "model_config") and hasattr(vllm_config, "model_config")
......
...@@ -1177,9 +1177,9 @@ def init_distributed_environment( ...@@ -1177,9 +1177,9 @@ def init_distributed_environment(
distributed_init_method, distributed_init_method,
backend, backend,
) )
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config() config = get_current_vllm_config_or_none()
if ( if (
config is not None config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher" and config.parallel_config.distributed_executor_backend != "external_launcher"
...@@ -1251,7 +1251,7 @@ def init_distributed_environment( ...@@ -1251,7 +1251,7 @@ def init_distributed_environment(
if _WORLD is None: if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size())) ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend) _WORLD = init_world_group(ranks, local_rank, backend)
if config.parallel_config.nnodes > 1: if config is not None and config.parallel_config.nnodes > 1:
_NODE_COUNT = config.parallel_config.nnodes _NODE_COUNT = config.parallel_config.nnodes
else: else:
_NODE_COUNT = _node_count(_WORLD.cpu_group) _NODE_COUNT = _node_count(_WORLD.cpu_group)
...@@ -1260,7 +1260,7 @@ def init_distributed_environment( ...@@ -1260,7 +1260,7 @@ def init_distributed_environment(
assert _WORLD.world_size == torch.distributed.get_world_size(), ( assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size" "world group already initialized with a different world size"
) )
if config.parallel_config.nnodes_within_dp > 1: if config is not None and config.parallel_config.nnodes_within_dp > 1:
if parallel_config.data_parallel_size > 1: if parallel_config.data_parallel_size > 1:
world_size_inner_dp = parallel_config.world_size world_size_inner_dp = parallel_config.world_size
group_ranks = [ group_ranks = [
...@@ -1316,9 +1316,9 @@ def initialize_model_parallel( ...@@ -1316,9 +1316,9 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(get_world_group().device_group) backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1 data_parallel_size = 1
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config() config = get_current_vllm_config_or_none()
if config is not None: if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size data_parallel_size = config.parallel_config.data_parallel_size
......
...@@ -13,10 +13,28 @@ from vllm.model_executor.layers.quantization.utils.layer_utils import replace_pa ...@@ -13,10 +13,28 @@ from vllm.model_executor.layers.quantization.utils.layer_utils import replace_pa
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
_CPU_MOE_LAYER_CACHE = {} _CPU_MOE_LAYER_CACHE = {}
class _LazyActivationDict(dict):
    """Dict that builds each supported activation module on first lookup.

    Deferring construction keeps CustomOp.__init__() from running at module
    import time, where it would call get_current_vllm_config() before the
    config has been set.
    """

    # Maps an activation name to the class instantiated for it on first use.
    _factories: dict[str, type[SiluAndMul] | type[SwigluOAIAndMul]] = {
        "silu": SiluAndMul,
        "swigluoai": SwigluOAIAndMul,
    }

    def __missing__(self, key: str) -> SiluAndMul | SwigluOAIAndMul:
        """Instantiate, cache, and return the activation registered as *key*.

        Raises:
            KeyError: if *key* has no entry in ``_factories``.
        """
        factory = self._factories.get(key)
        if factory is None:
            raise KeyError(f"{key} is not a supported activation")
        # Store the instance so later lookups hit the dict directly and
        # never reach __missing__ again.
        self[key] = factory()
        return self[key]


_CPU_MOE_ACT = _LazyActivationDict()
def grouped_topk( def grouped_topk(
...@@ -212,7 +230,7 @@ class CPUFusedMOE: ...@@ -212,7 +230,7 @@ class CPUFusedMOE:
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation in _CPU_MOE_ACT, f"{activation} is not supported." assert activation in _CPU_MOE_ACT._factories, f"{activation} is not supported."
assert not apply_router_weight_on_input assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
......
...@@ -540,6 +540,20 @@ class FusedMoE(CustomOp): ...@@ -540,6 +540,20 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation self.activation = activation
self._grouped_topk_impl: GroupedTopk | None = None
if self.use_grouped_topk:
assert self.num_expert_group is not None
assert self.topk_group is not None
self._grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError( raise ValueError(
"Only softmax scoring function is supported for non-grouped topk." "Only softmax scoring function is supported for non-grouped topk."
...@@ -1588,19 +1602,8 @@ class FusedMoE(CustomOp): ...@@ -1588,19 +1602,8 @@ class FusedMoE(CustomOp):
# DeepSeekv2 uses grouped_top_k # DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk and valid_grouping(): elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None assert self._grouped_topk_impl is not None
assert self.num_expert_group is not None topk_weights, topk_ids = self._grouped_topk_impl(
grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
......
...@@ -339,15 +339,11 @@ def apply_rotary_pos_emb_flashatt( ...@@ -339,15 +339,11 @@ def apply_rotary_pos_emb_flashatt(
k: torch.Tensor, k: torch.Tensor,
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
q_embed = apply_rotary_emb(q, cos, sin) q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin) k_embed = apply_rotary_emb(k, cos, sin)
...@@ -410,6 +406,11 @@ class KeyeSiglipAttention(nn.Module): ...@@ -410,6 +406,11 @@ class KeyeSiglipAttention(nn.Module):
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -448,7 +449,7 @@ class KeyeSiglipAttention(nn.Module): ...@@ -448,7 +449,7 @@ class KeyeSiglipAttention(nn.Module):
self.num_kv_heads, self.num_kv_heads,
self.head_dim, self.head_dim,
) )
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin) q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin, self.apply_rotary_emb)
v = v.view( v = v.view(
*v.shape[:-1], *v.shape[:-1],
self.num_kv_heads, self.num_kv_heads,
......
...@@ -152,16 +152,12 @@ def apply_rotary_pos_emb( ...@@ -152,16 +152,12 @@ def apply_rotary_pos_emb(
k: torch.Tensor, k: torch.Tensor,
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
is_flash_attn_backend: bool = False, is_flash_attn_backend: bool,
apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and current_platform.is_cuda(): if is_flash_attn_backend and current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda apply_rotary_emb_func = apply_rotary_emb.forward_cuda
elif is_flash_attn_backend and current_platform.is_rocm(): elif is_flash_attn_backend and current_platform.is_rocm():
...@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module): ...@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module):
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module): ...@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module):
cos, cos,
sin, sin,
self.attn.is_flash_attn_backend, self.attn.is_flash_attn_backend,
self.apply_rotary_emb,
) )
queries = queries.squeeze(0) queries = queries.squeeze(0)
keys = keys.squeeze(0) keys = keys.squeeze(0)
......
...@@ -14,7 +14,7 @@ import torch.distributed ...@@ -14,7 +14,7 @@ import torch.distributed
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.config.compilation import CompilationMode from vllm.config.compilation import CompilationMode
from vllm.distributed import ( from vllm.distributed import (
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
...@@ -268,7 +268,9 @@ class Worker(WorkerBase): ...@@ -268,7 +268,9 @@ class Worker(WorkerBase):
# to hijack tensor allocation. # to hijack tensor allocation.
def load_model(self) -> None:
    """Load the model through the model runner.

    The load runs with two context managers active: the "weights"
    memory-pool context returned by ``_maybe_get_memory_pool_context`` and
    ``set_current_vllm_config(self.vllm_config)``.

    ``eep_scale_up`` is True only when the environment variable
    ``VLLM_ELASTIC_EP_SCALE_UP_LAUNCH`` equals ``"1"``; it is forwarded
    unchanged to ``self.model_runner.load_model``.
    """
    eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
    weights_pool_ctx = self._maybe_get_memory_pool_context(tag="weights")
    # The managers must be separated by a comma, not `and`:
    # `with a and b:` evaluates the boolean expression first and enters only
    # its result (`b`, since a context-manager object is truthy), so the
    # memory-pool context would silently never be entered.
    with weights_pool_ctx, set_current_vllm_config(self.vllm_config):
        self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None: def update_config(self, overrides: dict[str, Any]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment