Commit 82cd3c88 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

# Conflicts:
#	vllm/envs.py
parents 35e43dfb 7d5faa43
......@@ -553,18 +553,7 @@ def unified_attention_with_output(
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return
else:
def unified_attention_with_output_fake(
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
......
......@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
......@@ -90,12 +90,24 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0
self._dp_rank = get_dp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._pp_rank = get_pp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._tp_rank = get_tp_group().rank_in_group \
if role == KVConnectorRole.WORKER else 0
self._dp_size = get_dp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._pp_size = get_pp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank,
config=self.config,
model_config=vllm_config.model_config,
) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config
......@@ -105,9 +117,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.pp_size = self.parallel_config.pipeline_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size
self.num_card = self.pp_size * self.tp_size
self.multiple_machines = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines == 1:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", self.tp_size)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", self.pp_size)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
self.remote_num_card = self.remote_tp_size * self.remote_pp_size
self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
self.multiple_machines_p = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines_p == 1:
self.ip_map = {}
self.duplicate_keys = []
config_file = os.getenv('IP_CONFIG_FILE')
......@@ -353,6 +375,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
......@@ -417,7 +441,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines):
if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
......@@ -433,29 +457,43 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
print("Error: only suppprt pp1 pp2 !!!!!!")
else:
if (self.pp_size == 1):
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4))
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4))
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i))
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
logger.error("Error: not support!!!!!!")
def wait_for_save(self):
pass
# if self.is_producer:
......
......@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import msgpack
import torch
import zmq
import regex
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
......@@ -23,6 +24,11 @@ from vllm.utils import current_stream, get_ip
from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from dataclasses import dataclass
from vllm.model_executor.models.utils import extract_layer_index
from vllm.distributed.utils import get_pp_indices
from vllm.config import ModelConfig
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
......@@ -30,6 +36,11 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager
def set_p2p_nccl_context(num_channels: str):
......@@ -65,22 +76,39 @@ class P2pNcclEngine:
def __init__(self,
local_rank: int,
port_offset: int,
config: KVTransferConfig,
hostname: str = "",
port_offset: int = 0,
model_config: ModelConfig,
library_path: Optional[str] = None) -> None:
self.config = config
self.model_config = model_config
self.rank = port_offset
self.local_rank = local_rank
self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path)
if not hostname:
hostname = get_ip()
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
"num_hidden_layers", 0)
self.pp_rank = get_pp_group().rank_in_group
self.tp_rank = get_tp_group().rank_in_group
self.pp_size = get_pp_group().world_size
self.tp_size = get_tp_group().world_size
if config.is_kv_producer:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", 1)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._hostname = get_ip()
self._port = port
# Each card corresponds to a ZMQ address.
......@@ -195,6 +223,61 @@ class P2pNcclEngine:
return self.socks[remote_address], self.comms[remote_address]
def get_send_queue_items(self, request_id: str, layer_name: str,
tensor: torch.Tensor,
is_mla: bool) -> list[any]:
tensor_id = self.get_tensor_id(request_id, layer_name)
remote_ip, remote_port = self.parse_request_id(request_id, True)
if not self.enable_asymmetric_p2p:
remote_address = remote_ip + ":" + str(remote_port + self.rank)
return [(tensor_id, remote_address, tensor)]
if not is_mla:
logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")
remote_pp_rank = self.compute_remote_pp_rank(layer_name)
items: list[Any] = []
up_down = 1
# remote_tp_rank = self.tp_rank * self.multp
for d_tp_rank in range(self.remote_tp_size):
for mul_tp in range(self.multp):
if self.tp_rank + mul_tp * self.tp_size == d_tp_rank:
remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
remote_address = remote_ip + ":" + str(remote_port + remote_port_offset)
logger.debug(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)", tensor_id,
tensor.shape, self.pp_rank, self.tp_rank, remote_address,
remote_pp_rank, self.rank * mul_tp + self.rank)
items.append([tensor_id, remote_address, tensor])
return items
def send_tensor_new(
self,
request_id: str,
layer_name: str,
tensor: torch.Tensor,
is_mla: bool = False,
) -> bool:
tensor_id = self.get_tensor_id(request_id, layer_name)
if self.send_type == "PUT":
return all(
self.send_sync(item) for item in self.get_send_queue_items(
request_id, layer_name, tensor, is_mla))
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
for item in self.get_send_queue_items(request_id, layer_name,
tensor, is_mla):
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
if self.send_type == "GET":
logger.error(" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!")
def send_tensor(
self,
tensor_id: str,
......@@ -659,3 +742,38 @@ class P2pNcclEngine:
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
def compute_remote_pp_rank(self, layer_name: str) -> int:
current_layer_idx = extract_layer_index(layer_name)
for d_pp_rank in range(self.remote_pp_size):
start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if (current_layer_idx == self.total_num_hidden_layers):
return self.remote_pp_size - 1
if start <= current_layer_idx < end:
return d_pp_rank
return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
return request_id + "#" + layer_name
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = regex.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
\ No newline at end of file
......@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
def get_default_cache_root():
......@@ -1070,7 +1071,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))),
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
......@@ -1276,11 +1277,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")),
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -253,8 +253,6 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
......@@ -298,8 +296,6 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
......
......@@ -28,6 +28,8 @@ from .interfaces import SupportsPP
from .utils import maybe_prefix
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
class SharedHead(nn.Module):
......@@ -72,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config)
def fuse_fill_rms_x2_concat(hidden_states_fuse: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor,
previous_hidden_states: torch.Tensor, weight_inputs_embeds: torch.Tensor,
weight_previous_hidden_states: torch.Tensor, epsilon: float) -> None:
from lightop import fuse_fill_rms_x2_concat
fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, weight_inputs_embeds, weight_previous_hidden_states, epsilon)
def fuse_fill_rms_x2_concat_fake(hidden_states_fuse: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor,
previous_hidden_states: torch.Tensor, weight_inputs_embeds: torch.Tensor,
weight_previous_hidden_states: torch.Tensor, epsilon: float) -> None:
pass
direct_register_custom_op(
op_name="fuse_fill_rms_x2_concat",
op_func=fuse_fill_rms_x2_concat,
mutates_args=["hidden_states_fuse", "inputs_embeds"],
fake_impl=fuse_fill_rms_x2_concat_fake,
)
def forward(
self,
input_ids: torch.Tensor,
......@@ -84,10 +104,14 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
hidden_states_fuse = torch.empty(inputs_embeds.shape[0], inputs_embeds.shape[1]*2, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
torch.ops.vllm.fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, self.enorm.weight, self.hnorm.weight, self.enorm.variance_epsilon)
hidden_states = self.eh_proj(hidden_states_fuse)
else:
inputs_embeds[positions == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
......
This diff is collapsed.
......@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment