Commit 5a5e4f3b authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev-ds' into v0.9.2-dev-ds

# Conflicts:
#	vllm/model_executor/layers/fused_moe/ep_moe/layer.py
parents f505d366 a7992f79
......@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
| Glm4MoeForCausalLM | GLM-4.5,GLM-4.5-Air | No/Yes | - | - | v0.9.2 | Yes |
| DeepseekForCausalLM | Deepseek | Yes | No | - | v0.5.0 | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - | v0.6.2 | Yes |
| DeepseekVLV2ForCausalLM | DeepSeek-VL2 | Yes | No | - | v0.7.2 | Yes |
......
......@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'):
version = 'das.opt1.rc2.' + sha[:7]
version = 'das.opt1.' + sha[:7]
else:
if (major, minor) >= ('2', '5'):
version = 'das.opt1.rc2'
version = 'das.opt1'
# dtk version
......
......@@ -2174,7 +2174,6 @@ def gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -944,11 +944,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len)
assert max_seq_len is not None
if use_custom:
max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
......@@ -1002,6 +1003,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache
if decode_meta.use_cuda_graph:
max_seq_len = 0
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
......@@ -1024,6 +1027,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
k_scale=layer._k_scale,
v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype,
max_seqlen_k=max_seq_len,
).squeeze(1)
else:
out_pa[:] = paged_attn.forward_decode(
......
......@@ -30,6 +30,7 @@ try:
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module):
"""Attention layer.
......@@ -439,6 +440,7 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
......
......@@ -326,7 +326,7 @@ class ModelConfig:
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
model does not support sliding window, this argument is ignored."""
disable_cascade_attn: bool = False
disable_cascade_attn: bool = True
"""Disable cascade attention for V1. While cascade attention does not
change the mathematical correctness, disabling it could be useful for
preventing potential numerical issues. Note that even if this is set to
......@@ -418,7 +418,6 @@ class ModelConfig:
- "transformers" will use the Transformers model implementation."""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
......
......@@ -18,6 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
......@@ -213,6 +214,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
kv_cache_layer = kv_cache[ \
forward_context.virtual_engine]
if not envs.VLLM_P2P_ASYNC:
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
......@@ -234,6 +236,61 @@ class P2pNcclConnector(KVConnectorBase_V1):
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
else:
dst_kv_cache_layer_shape = kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
num_pages * page_size, -1)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
2, num_pages * page_size, -1)
inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_token = kv_cache.shape[0]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
else:
num_token = kv_cache.shape[1]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[:, request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[:, request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
inject_start_index += num_token
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
def wait_for_layer_load(self, layer_name: str) -> None:
......@@ -296,30 +353,29 @@ class P2pNcclConnector(KVConnectorBase_V1):
request.slot_mapping_device = \
request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
slot_mapping = request.slot_mapping_device
kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
elif (self.pp_size == 2):
if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i), tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else:
......
......@@ -21,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
......@@ -111,9 +112,12 @@ class P2pNcclEngine:
self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream()
# self.send_stream = tbo_all_reduce_stream
self.recv_stream = torch.cuda.Stream()
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None
self.tensor_split_num: int = 0
mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
......@@ -208,7 +212,54 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt])
self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def p2p_async_send_tensor(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
else:
if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
kv_layer, slot_mapping = tensor # tesor (kv_layer, slot_mapping)
self.send_queue.append([tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
......@@ -313,6 +364,8 @@ class P2pNcclEngine:
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
......@@ -392,11 +445,16 @@ class P2pNcclEngine:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft()
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt = self.send_queue.popleft()
else:
tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self):
......@@ -410,6 +468,75 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank)
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
                      slot_mapping: torch.Tensor, remote_address: str) -> bool:
    """Gather the KV rows named by ``slot_mapping`` and send them in packs.

    The selected rows are copied into a reusable staging buffer
    (``self.p2p_async_buf``) at most ``self.p2p_async_kv_tokens`` rows at a
    time. For every pack a msgpack "PUT" header is sent over the ZMQ socket,
    the reply is checked, and the staged slice is then sent with
    ``self._send`` on ``self.send_stream``.

    Args:
        tensor_id: Base identifier; each pack is announced as
            ``tensor_id + "#" + str(pack_idx)``.
        kv_layer: KV cache of one layer. A 3-D tensor is handled as
            ``(num_pages, page_size, hidden)``; any other rank is handled as
            ``(2, num_pages, page_size, ...)``.
        slot_mapping: 1-D index tensor of the flattened slots to send.
        remote_address: Peer address; a connection is created on first use.

    Returns:
        ``True`` when every pack was acknowledged with ``b"0"``, ``False`` as
        soon as the peer answers anything else (remaining packs are skipped).
    """
    if remote_address not in self.socks:
        self._create_connect(remote_address)

    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    is_mla = (kv_layer.ndim == 3)
    if is_mla:
        num_pages, page_size = kv_layer.shape[0], kv_layer.shape[1]
        trailing_dims = kv_layer.shape[2:]
    else:
        num_pages, page_size = kv_layer.shape[1], kv_layer.shape[2]
        trailing_dims = kv_layer.shape[3:]

    # Width of one flattened slot. The gather source below is flattened with
    # reshape(..., -1), so the width must be the product of *all* trailing
    # dimensions; this equals kv_layer.shape[-1] whenever there is exactly
    # one trailing dimension.
    hidden_dim = 1
    for extent in trailing_dims:
        hidden_dim *= extent

    # The staging buffer is allocated once and reused by later calls.
    # NOTE(review): this assumes every layer sent through this engine shares
    # the same layout, hidden size, dtype and device - confirm.
    if self.p2p_async_buf is None:
        if is_mla:
            buf_shape = (self.p2p_async_kv_tokens, hidden_dim)
        else:
            buf_shape = (2, self.p2p_async_kv_tokens, hidden_dim)
        self.p2p_async_buf = torch.empty(buf_shape,
                                         dtype=kv_layer.dtype,
                                         device=kv_layer.device)

    num_tokens = slot_mapping.shape[0]
    # Ceil-division: number of packs needed to cover all selected slots.
    pack_num = (num_tokens - 1) // self.p2p_async_kv_tokens + 1
    self.tensor_split_num = pack_num
    dtype_name = str(kv_layer.dtype).replace("torch.", "")

    with torch.cuda.stream(self.send_stream):
        # Flatten (pages, page_size) into a single slot axis once; this is
        # loop-invariant and may copy when kv_layer is not contiguous.
        if is_mla:
            data = kv_layer.reshape(num_pages * page_size, -1)
        else:
            data = kv_layer.reshape(2, num_pages * page_size, -1)

        for pack_idx in range(pack_num):
            start = pack_idx * self.p2p_async_kv_tokens
            end = min(start + self.p2p_async_kv_tokens, num_tokens)
            sub_index = slot_mapping[start:end]

            # Gather this pack's slots directly into the staging slice that
            # is sent afterwards.
            if is_mla:
                send_tensor = self.p2p_async_buf[:end - start]
                torch.index_select(data, dim=0, index=sub_index,
                                   out=send_tensor)
                tx_shape = (end - start, hidden_dim)
            else:
                send_tensor = self.p2p_async_buf[:, :end - start]
                torch.index_select(data, dim=1, index=sub_index,
                                   out=send_tensor)
                tx_shape = (2, end - start, hidden_dim)

            header = {
                "cmd": "PUT",
                # Append pack_idx so every pack has its own tensor id.
                "tensor_id": tensor_id + "#" + str(pack_idx),
                "pack_idx": pack_idx,
                "tensor_split_num": pack_num,
                "shape": tx_shape,
                "dtype": dtype_name
            }
            sock.send(msgpack.dumps(header))

            response = sock.recv()
            if response != b"0":
                logger.error(
                    "🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
                    self.zmq_address, remote_address, rank,
                    tuple(send_tensor.shape), send_tensor.element_size() * send_tensor.numel() / 1024**3,
                    response.decode()
                )
                return False

            self._send(comm, send_tensor, rank ^ 1, self.send_stream)

    # Bookkeeping uses the base id (without the pack suffix).
    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)

    return True
def _send_sync(
self,
tensor_id: str,
......
......@@ -166,11 +166,15 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_MORI_EP: bool = False
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
def get_default_cache_root():
return os.getenv(
......@@ -1104,6 +1108,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
("true", "1")),
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
("true", "1")),
# vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......@@ -1131,6 +1147,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -30,11 +30,11 @@ try:
except ImportError:
is_mori_available = False
logger = init_logger(__name__)
_MORI_OP = None
@CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization."""
......@@ -167,6 +167,7 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int, # Global number of experts
......@@ -229,24 +230,35 @@ class EPMoE(FusedMoE):
]
self.use_shared_expert = False
<<<<<<< HEAD
# self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# )
=======
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
)
>>>>>>> origin/v0.9.2-dev-ds
self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
self.scales = None
self.use_int8_dispatch = True
vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op()
<<<<<<< HEAD
=======
self.first = True
>>>>>>> origin/v0.9.2-dev-ds
def get_mori_op(self):
global _MORI_OP
if _MORI_OP is None:
......@@ -262,7 +274,7 @@ class EPMoE(FusedMoE):
vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype
mori_data_type = vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch:
mori_scale_type_size = 4
......@@ -280,7 +292,7 @@ class EPMoE(FusedMoE):
max_token_type_size=2,
block_num=80,
warp_num_per_block=16,
#kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode
)
......@@ -307,7 +319,11 @@ class EPMoE(FusedMoE):
return quant_method
def sync(self):
    """Wait for the local device, then barrier across the process group.

    Calls ``torch.cuda.synchronize()`` followed by ``dist.barrier()``.
    """
    # Unresolved merge-conflict markers had been committed in this body,
    # which is a syntax error. The HEAD side is kept: it still performs the
    # device synchronize that the incoming branch had commented out.
    # NOTE(review): confirm that keeping the synchronize is the intended
    # resolution.
    torch.cuda.synchronize()
    dist.barrier()
def forward(self, hidden_states: torch.Tensor,
......@@ -334,7 +350,6 @@ class EPMoE(FusedMoE):
if name not in NON_EXPERT_WEIGHTS
]
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
......@@ -368,8 +383,7 @@ class EPMoE(FusedMoE):
)
scales = self.scales
#self.sync()
# self.sync()
(
dispatch_output,
......@@ -382,9 +396,35 @@ class EPMoE(FusedMoE):
topk_weights,
scales,
topk_ids,
layer_idx=int(self.layer_name.split('.')[2])
#layer_idx=int(self.layer_name.split('.')[2])
)
<<<<<<< HEAD
#self.sync()
=======
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output_clip,
# topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
>>>>>>> origin/v0.9.2-dev-ds
expert_output = self.quant_method.apply_ep(
layer=self,
......@@ -399,21 +439,21 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
scales=dispatch_scales if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor,
# routed_scaling_factor=self.routed_scaling_factor,
)
#self.sync()
# self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :]
#self.sync()
# self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
......@@ -423,6 +463,7 @@ class EPMoE(FusedMoE):
return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
......@@ -443,5 +484,5 @@ direct_register_custom_op(
mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
......@@ -181,7 +181,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.use_all_gather = current_platform.use_all_gather()
self.probs = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
# For smuggling this layer into the fused moe custom op
vllm_config = get_current_vllm_config()
......@@ -446,7 +445,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts.get_output()
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
if hidden_states.dtype != torch.float16:
output = output + shared_output
else:
# Fix FP16 overflow
......
......@@ -45,14 +45,141 @@ from lightop import op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
    """Reduce ``x`` over dimension 1 into ``out`` and scale ``out`` in place.

    Args:
        x: Tensor to reduce; dimension 1 is summed away.
        out: Destination tensor, written by ``torch.sum`` and then
            multiplied in place.
        routed_scaling_factor: Multiplier applied to ``out`` after the sum.
    """
    # Write the reduction straight into the caller's buffer ...
    torch.sum(x, dim=1, out=out)
    # ... and scale it without allocating a new tensor.
    out.mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """Sum the top-k axis of a 3-D input into a 2-D output, then scale.

    Program axis 0 selects a block of ``BLOCK_M`` tokens and program axis 1
    selects a block of ``BLOCK_DIM`` hidden columns. For each token the
    ``topk_num`` rows are accumulated in float32, multiplied by
    ``routed_scaling_factor`` and stored cast to the input element type.

    ``input_stride_2`` and ``output_stride_1`` are accepted but not used:
    column offsets are added to the pointers directly.
    """
    # Promote the strides that are multiplied by indices to 64-bit.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    pid_token = tl.program_id(0)
    pid_dim = tl.program_id(1)

    # Token rows handled by this program, clamped to the real token count.
    row_begin = pid_token * BLOCK_M
    row_end = min((pid_token + 1) * BLOCK_M, token_num)

    # Hidden columns handled by this program, clamped to the real width.
    col_begin = pid_dim * BLOCK_DIM
    col_end = min((pid_dim + 1) * BLOCK_DIM, hidden_dim)

    cols = col_begin + tl.arange(0, BLOCK_DIM)
    col_mask = cols < col_end

    for row in range(row_begin, row_end):
        acc = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        row_in_ptr = input_ptr + row * input_stride_0 + cols

        for k in tl.range(0, topk_num, num_stages=NUM_STAGE):
            acc += tl.load(
                row_in_ptr + k * input_stride_1, mask=col_mask, other=0.0
            )

        acc = acc * routed_scaling_factor

        row_out_ptr = output_ptr + row * output_stride_0 + cols
        tl.store(
            row_out_ptr,
            acc.to(input_ptr.dtype.element_ty),
            mask=col_mask,
        )
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Launch ``_moe_sum_reduce_kernel`` to reduce ``input`` into ``output``.

    Args:
        input: Contiguous tensor unpacked as
            ``(token_num, topk_num, hidden_dim)``.
        output: Contiguous tensor whose first two dimensions are
            ``(token_num, hidden_dim)``.
        routed_scaling_factor: Scale forwarded to the kernel.

    Raises:
        AssertionError: If either tensor is not contiguous or the output
            shape does not match the input.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    # Launch parameters are bucketed by token count; every bucket processes
    # one token per program along grid axis 0.
    block_m = 1
    if token_num <= 32:
        block_dim, num_stage, warps = 512, 2, 4
    elif token_num <= 128:
        block_dim, num_stage, warps = 1024, 0, 2
    elif token_num <= 4096:
        block_dim, num_stage, warps = 2048, 0, 2
    else:
        block_dim, num_stage, warps = 2048, 2, 8

    launch_grid = (
        triton.cdiv(token_num, block_m),
        triton.cdiv(hidden_dim, block_dim),
    )

    _moe_sum_reduce_kernel[launch_grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=block_m,
        BLOCK_DIM=block_dim,
        NUM_STAGE=num_stage,
        num_warps=warps,
    )
def moe_reduce_dispatch(
    intermediate_cache3: torch.Tensor,
    out_hidden_states: torch.Tensor,
    begin_chunk_idx: int,
    end_chunk_idx: int,
):
    """Sum ``intermediate_cache3`` into the output chunk, picking a backend by size.

    Args:
        intermediate_cache3: Tensor to reduce; its first dimension ``n`` is
            the number of rows in the current chunk.
        out_hidden_states: Either the destination chunk itself (``n`` rows)
            or a larger output tensor that is sliced with
            ``[begin_chunk_idx:end_chunk_idx]``.
        begin_chunk_idx: Start row of the chunk in the larger output tensor.
        end_chunk_idx: End row (exclusive) of the chunk in the larger output
            tensor.
    """
    inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
    n = intermediate_cache3.shape[0]

    # Callers may hand in an output that is already the chunk view. Slicing
    # that view again with the chunk indices would select the wrong (or an
    # empty) range for every chunk after the first, so only slice when the
    # output is not already chunk-sized.
    if out_hidden_states.shape[0] == n:
        out_chunk = out_hidden_states
    else:
        out_chunk = out_hidden_states[begin_chunk_idx:end_chunk_idx]

    # Choose the reduce implementation according to the size of n.
    if 1 <= n <= 4:
        moe_sum_reduce_torch_compile(inter_cache_view, out_chunk, 1.0)
    elif 4 < n <= 1024:
        moe_sum_reduce_triton(inter_cache_view, out_chunk, 1.0)
    elif 1024 < n <= 32768:
        ops.moe_sum_opt1(inter_cache_view, out_chunk)
    else:
        ops.moe_sum(inter_cache_view, out_chunk)
def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton
if moe_cache_singleton is None:
......@@ -1266,7 +1393,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
routed_scaling_factor: Optional[float] = 1.0) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
......@@ -1300,7 +1427,7 @@ def inplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
routed_scaling_factor: Optional[float] = 1.0) -> None:
pass
......@@ -1338,7 +1465,7 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
......@@ -1372,7 +1499,7 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1431,7 +1558,7 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.size(1)
......@@ -1520,7 +1647,7 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
if use_nn_moe:
......@@ -1769,7 +1896,7 @@ def fused_experts_impl(
block_shape=block_shape,
use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
......@@ -1788,6 +1915,14 @@ def fused_experts_impl(
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
if envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
......@@ -1825,7 +1960,7 @@ def fused_moe(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......
......@@ -376,7 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
if enable_eplb:
......@@ -423,7 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation: str = "silu",
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
......@@ -487,7 +487,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False,
**kwargs,
):
......@@ -527,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor:
assert not use_grouped_topk
......@@ -683,7 +683,7 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu",
enable_eplb: bool = False,
num_redundant_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
):
super().__init__()
if params_dtype is None:
......@@ -1269,7 +1269,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
......
......@@ -239,6 +239,14 @@ def moe_align_block_size(
expert_map = expert_map,
expert_mask = expert_mask,
num_local_tokens = None)
else:
if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad,
expert_map = None,
expert_mask = None,
num_local_tokens = None)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
......
......@@ -514,6 +514,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......
......@@ -473,6 +473,7 @@ class BlockInt8MoEMethod:
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -514,5 +515,7 @@ class BlockInt8MoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
use_nn_moe=use_nn_moe
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
\ No newline at end of file
......@@ -348,6 +348,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
......@@ -430,7 +431,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
use_nn_moe=False)
use_nn_moe=False,
shared_output=shared_output,
)
@staticmethod
def get_weight_loader(layer, weight_loader):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment