Commit 82cd3c88 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

# Conflicts:
#	vllm/envs.py
parents 35e43dfb 7d5faa43
...@@ -553,18 +553,7 @@ def unified_attention_with_output( ...@@ -553,18 +553,7 @@ def unified_attention_with_output(
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: def unified_attention_with_output_fake(
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return
else:
def unified_attention_with_output_fake(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
......
...@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context ...@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -90,12 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -90,12 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \ self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._dp_rank = get_dp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._pp_rank = get_pp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._tp_rank = get_tp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._dp_size = get_dp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._pp_size = get_pp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine( self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank, local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank, port_offset=self._rank,
config=self.config,
model_config=vllm_config.model_config,
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -105,9 +117,19 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -105,9 +117,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.pp_size = self.parallel_config.pipeline_parallel_size self.pp_size = self.parallel_config.pipeline_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size self.tp_size = self.parallel_config.tensor_parallel_size
self.num_card = self.pp_size * self.tp_size self.num_card = self.pp_size * self.tp_size
self.multiple_machines = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines == 1: self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", self.tp_size)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", self.pp_size)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
self.remote_num_card = self.remote_tp_size * self.remote_pp_size
self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
self.multiple_machines_p = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines_p == 1:
self.ip_map = {} self.ip_map = {}
self.duplicate_keys = [] self.duplicate_keys = []
config_file = os.getenv('IP_CONFIG_FILE') config_file = os.getenv('IP_CONFIG_FILE')
...@@ -353,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -353,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
...@@ -417,7 +441,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -417,7 +441,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size ) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines): if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip) ip_second = self.get_ip_value(ip)
if (self.pp_size == 1): if (self.pp_size == 1):
if self._rank < 8: if self._rank < 8:
...@@ -433,29 +457,43 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -433,29 +457,43 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank)) kv_cache, str(ip_second) + ":" + str(port + self._rank))
else: else:
print("Error: only suppprt pp1 pp2 !!!!!!") logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
else: elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 1): if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_address)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4))
else: else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, elif (not self.multiple_machines_p and not self.multiple_machines_d):
kv_cache, ip + ":" + str(port + self._rank - 4)) self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
elif (self.pp_size == 8): is_mla)
for i in range(8): # if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, # self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i)) # kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else: else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") logger.error("Error: not support!!!!!!")
def wait_for_save(self): def wait_for_save(self):
pass pass
# if self.is_producer: # if self.is_producer:
......
...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import msgpack import msgpack
import torch import torch
import zmq import zmq
import regex
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
...@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip ...@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip
from vllm import envs from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from dataclasses import dataclass
from vllm.model_executor.models.utils import extract_layer_index
from vllm.distributed.utils import get_pp_indices
from vllm.config import ModelConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__) ...@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32 DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager @contextmanager
def set_p2p_nccl_context(num_channels: str): def set_p2p_nccl_context(num_channels: str):
...@@ -65,22 +76,39 @@ class P2pNcclEngine: ...@@ -65,22 +76,39 @@ class P2pNcclEngine:
def __init__(self, def __init__(self,
local_rank: int, local_rank: int,
port_offset: int,
config: KVTransferConfig, config: KVTransferConfig,
hostname: str = "", model_config: ModelConfig,
port_offset: int = 0,
library_path: Optional[str] = None) -> None: library_path: Optional[str] = None) -> None:
self.config = config self.config = config
self.model_config = model_config
self.rank = port_offset self.rank = port_offset
self.local_rank = local_rank self.local_rank = local_rank
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path) self.nccl = NCCLLibrary(library_path)
if not hostname: self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
hostname = get_ip() "num_hidden_layers", 0)
self.pp_rank = get_pp_group().rank_in_group
self.tp_rank = get_tp_group().rank_in_group
self.pp_size = get_pp_group().world_size
self.tp_size = get_tp_group().world_size
if config.is_kv_producer:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", 1)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
self._hostname = hostname self._hostname = get_ip()
self._port = port self._port = port
# Each card corresponds to a ZMQ address. # Each card corresponds to a ZMQ address.
...@@ -195,6 +223,61 @@ class P2pNcclEngine: ...@@ -195,6 +223,61 @@ class P2pNcclEngine:
return self.socks[remote_address], self.comms[remote_address] return self.socks[remote_address], self.comms[remote_address]
def get_send_queue_items(self, request_id: str, layer_name: str,
tensor: torch.Tensor,
is_mla: bool) -> list[any]:
tensor_id = self.get_tensor_id(request_id, layer_name)
remote_ip, remote_port = self.parse_request_id(request_id, True)
if not self.enable_asymmetric_p2p:
remote_address = remote_ip + ":" + str(remote_port + self.rank)
return [(tensor_id, remote_address, tensor)]
if not is_mla:
logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")
remote_pp_rank = self.compute_remote_pp_rank(layer_name)
items: list[Any] = []
up_down = 1
# remote_tp_rank = self.tp_rank * self.multp
for d_tp_rank in range(self.remote_tp_size):
for mul_tp in range(self.multp):
if self.tp_rank + mul_tp * self.tp_size == d_tp_rank:
remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
remote_address = remote_ip + ":" + str(remote_port + remote_port_offset)
logger.debug(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)", tensor_id,
tensor.shape, self.pp_rank, self.tp_rank, remote_address,
remote_pp_rank, self.rank * mul_tp + self.rank)
items.append([tensor_id, remote_address, tensor])
return items
def send_tensor_new(
self,
request_id: str,
layer_name: str,
tensor: torch.Tensor,
is_mla: bool = False,
) -> bool:
tensor_id = self.get_tensor_id(request_id, layer_name)
if self.send_type == "PUT":
return all(
self.send_sync(item) for item in self.get_send_queue_items(
request_id, layer_name, tensor, is_mla))
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
for item in self.get_send_queue_items(request_id, layer_name,
tensor, is_mla):
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
if self.send_type == "GET":
logger.error(" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!")
def send_tensor( def send_tensor(
self, self,
tensor_id: str, tensor_id: str,
...@@ -659,3 +742,38 @@ class P2pNcclEngine: ...@@ -659,3 +742,38 @@ class P2pNcclEngine:
self._send_thread.join() self._send_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
def compute_remote_pp_rank(self, layer_name: str) -> int:
current_layer_idx = extract_layer_index(layer_name)
for d_pp_rank in range(self.remote_pp_size):
start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if (current_layer_idx == self.total_num_hidden_layers):
return self.remote_pp_size - 1
if start <= current_layer_idx < end:
return d_pp_rank
return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
return request_id + "#" + layer_name
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = regex.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
\ No newline at end of file
...@@ -196,6 +196,7 @@ if TYPE_CHECKING: ...@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
def get_default_cache_root(): def get_default_cache_root():
...@@ -1070,7 +1071,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1070,7 +1071,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE": "VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))), lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX": "VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
...@@ -1276,11 +1277,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1276,11 +1277,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE": "VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")),
# vLLM will use deepgemm kernel for deepep ht mode # vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -253,8 +253,6 @@ def get_model_architecture( ...@@ -253,8 +253,6 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
...@@ -298,8 +296,6 @@ def get_model_architecture( ...@@ -298,8 +296,6 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"): if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
......
...@@ -28,6 +28,8 @@ from .interfaces import SupportsPP ...@@ -28,6 +28,8 @@ from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
class SharedHead(nn.Module): class SharedHead(nn.Module):
...@@ -72,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -72,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config) cache_config, quant_config)
def fuse_fill_rms_x2_concat(hidden_states_fuse: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor,
previous_hidden_states: torch.Tensor, weight_inputs_embeds: torch.Tensor,
weight_previous_hidden_states: torch.Tensor, epsilon: float) -> None:
from lightop import fuse_fill_rms_x2_concat
fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, weight_inputs_embeds, weight_previous_hidden_states, epsilon)
def fuse_fill_rms_x2_concat_fake(hidden_states_fuse: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor,
previous_hidden_states: torch.Tensor, weight_inputs_embeds: torch.Tensor,
weight_previous_hidden_states: torch.Tensor, epsilon: float) -> None:
pass
direct_register_custom_op(
op_name="fuse_fill_rms_x2_concat",
op_func=fuse_fill_rms_x2_concat,
mutates_args=["hidden_states_fuse", "inputs_embeds"],
fake_impl=fuse_fill_rms_x2_concat_fake,
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -84,10 +104,14 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -84,10 +104,14 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP # masking inputs at position 0, as not needed by MTP
if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
hidden_states_fuse = torch.empty(inputs_embeds.shape[0], inputs_embeds.shape[1]*2, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
torch.ops.vllm.fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, self.enorm.weight, self.hnorm.weight, self.enorm.variance_epsilon)
hidden_states = self.eh_proj(hidden_states_fuse)
else:
inputs_embeds[positions == 0] = 0 inputs_embeds[positions == 0] = 0
inputs_embeds = self.enorm(inputs_embeds) inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states) previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj( hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
......
...@@ -22,22 +22,19 @@ ...@@ -22,22 +22,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import typing from collections.abc import Iterable
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any, Optional, Union from typing import Any, Optional, Union
import os import os
import re import re
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
...@@ -51,17 +48,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -51,17 +48,17 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader
default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
...@@ -108,86 +105,49 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -108,86 +105,49 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts: if self.tp_size > config.num_experts:
raise ValueError( raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than " f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.") f"the number of experts {config.num_experts}.")
# Load balancing settings. self.experts = FusedMoE(num_experts=config.num_experts,
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(num_experts=self.n_routed_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
reduce_results=True, reduce_results=False,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts")
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts, config.num_experts,
bias=False, bias=False,
quant_config=quant_config, quant_config=None,
prefix=f"{prefix}.gate") prefix=f"{prefix}.gate")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert hidden_states.dim( # NOTE: hidden_states can have either 1D or 2D shape.
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" orig_shape = hidden_states.shape
is_input_1d = hidden_states.dim() == 1 hidden_dim = hidden_states.shape[-1]
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if self.is_sequence_parallel: if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states, 0) final_hidden_states)
final_hidden_states = final_hidden_states[:num_tokens]
# return to 1d if input is 1d return final_hidden_states.view(orig_shape)
return final_hidden_states.squeeze(0) if is_input_1d else \
final_hidden_states
class Qwen3MoeAttention(nn.Module): class Qwen3MoeAttention(nn.Module):
...@@ -206,7 +166,6 @@ class Qwen3MoeAttention(nn.Module): ...@@ -206,7 +166,6 @@ class Qwen3MoeAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -230,7 +189,6 @@ class Qwen3MoeAttention(nn.Module): ...@@ -230,7 +189,6 @@ class Qwen3MoeAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(hidden_size, self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim, self.head_dim,
...@@ -252,25 +210,72 @@ class Qwen3MoeAttention(nn.Module): ...@@ -252,25 +210,72 @@ class Qwen3MoeAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
) )
self.attn = Attention( self.attn = Attention(self.num_heads,
self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn")
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {},
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -278,7 +283,36 @@ class Qwen3MoeAttention(nn.Module): ...@@ -278,7 +283,36 @@ class Qwen3MoeAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm if envs.VLLM_USE_FUSED_RMS_ROPE :
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
# Add qk-norm then RoPE (original path).
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim) self.head_dim)
if envs.VLLM_USE_APEX_RN: if envs.VLLM_USE_APEX_RN:
...@@ -302,21 +336,19 @@ class Qwen3MoeAttention(nn.Module): ...@@ -302,21 +336,19 @@ class Qwen3MoeAttention(nn.Module):
class Qwen3MoeDecoderLayer(nn.Module): class Qwen3MoeDecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", max_position_embeddings = getattr(config, "max_position_embeddings",
8192) 8192)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
self.self_attn = Qwen3MoeAttention( self.self_attn = Qwen3MoeAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
...@@ -330,7 +362,6 @@ class Qwen3MoeDecoderLayer(nn.Module): ...@@ -330,7 +362,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
) )
# `mlp_only_layers` in the config. # `mlp_only_layers` in the config.
...@@ -340,7 +371,8 @@ class Qwen3MoeDecoderLayer(nn.Module): ...@@ -340,7 +371,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
if (layer_idx not in mlp_only_layers) and ( if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0): (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp") prefix=f"{prefix}.mlp")
else: else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
...@@ -384,11 +416,9 @@ class Qwen3MoeModel(nn.Module): ...@@ -384,11 +416,9 @@ class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_text_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -403,11 +433,12 @@ class Qwen3MoeModel(nn.Module): ...@@ -403,11 +433,12 @@ class Qwen3MoeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens") prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, lambda prefix: Qwen3MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix), prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
...@@ -444,7 +475,8 @@ class Qwen3MoeModel(nn.Module): ...@@ -444,7 +475,8 @@ class Qwen3MoeModel(nn.Module):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -454,16 +486,6 @@ class Qwen3MoeModel(nn.Module): ...@@ -454,16 +486,6 @@ class Qwen3MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts)
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -480,9 +502,16 @@ class Qwen3MoeModel(nn.Module): ...@@ -480,9 +502,16 @@ class Qwen3MoeModel(nn.Module):
".v_scale", "_v_scale", ".weight_scale", ".v_scale", "_v_scale", ".weight_scale",
"_weight_scale", ".input_scale", "_input_scale") "_weight_scale", ".input_scale", "_input_scale")
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.use_llama_nn: if self.use_llama_nn:
current_count = loaded_weight.current_count current_count = loaded_weight.current_count
...@@ -508,68 +537,35 @@ class Qwen3MoeModel(nn.Module): ...@@ -508,68 +537,35 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if name not in params_dict: if name not in params_dict:
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = param.weight_loader
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
is_expert_weight = False
for mapping in expert_params_mapping: for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name)
# Anyway, this is an expert weight and should not be # Skip layers on other devices.
# attempted to load as other weights later if is_pp_missing_parameter(name, self):
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue continue
# Skip loading extra parameters for GPTQ/modelopt models. # Skip loading extra parameters for GPTQ/modelopt models.
if name_mapped.endswith( if name.endswith(
ignore_suffixes ignore_suffixes) and name not in params_dict:
) and name_mapped not in params_dict:
continue continue
param = params_dict[name]
param = params_dict[name_mapped] weight_loader = param.weight_loader
# We should ask the weight loader to return success or not weight_loader(param,
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight, loaded_weight,
name_mapped, name,
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id)
return_success=True)
if success:
name = name_mapped
break break
else: else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra parameters for GPTQ/modelopt models. # Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith( if name.endswith(
ignore_suffixes) and name not in params_dict: ignore_suffixes) and name not in params_dict:
...@@ -639,8 +635,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -639,8 +635,7 @@ class Qwen3MoeModel(nn.Module):
return loaded_params return loaded_params
class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -657,7 +652,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, ...@@ -657,7 +652,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_text_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
...@@ -665,74 +660,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, ...@@ -665,74 +660,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config)
prefix=maybe_prefix(prefix, "lm_head"))
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
# Set MoE hyperparameters
self.expert_weights = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -750,14 +684,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, ...@@ -750,14 +684,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states) logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
...@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16" kv_cache_dtype_str = "bf16"
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous( fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...], positions[:num_actual_toks, ...],
q, q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment