Commit 5a5e4f3b authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev-ds' into v0.9.2-dev-ds

# Conflicts:
#	vllm/model_executor/layers/fused_moe/ep_moe/layer.py
parents f505d366 a7992f79
...@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes | | Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes | | ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes | | Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
| Glm4MoeForCausalLM | GLM-4.5,GLM-4.5-Air | No/Yes | - | - | v0.9.2 | Yes |
| DeepseekForCausalLM | Deepseek | Yes | No | - | v0.5.0 | Yes | | DeepseekForCausalLM | Deepseek | Yes | No | - | v0.5.0 | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - | v0.6.2 | Yes | | DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - | v0.6.2 | Yes |
| DeepseekVLV2ForCausalLM | DeepSeek-VL2 | Yes | No | - | v0.7.2 | Yes | | DeepseekVLV2ForCausalLM | DeepSeek-VL2 | Yes | No | - | v0.7.2 | Yes |
......
...@@ -965,7 +965,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -965,7 +965,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3); vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else if (kv_cache_dtype == "fp8_e5m2") { } else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) { if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2); CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
...@@ -980,7 +980,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -980,7 +980,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2); vllm::Fp8KVCacheDataType::kFp8E5M2);
} }
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
......
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt1.rc2.' + sha[:7] version = 'das.opt1.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt1.rc2' version = 'das.opt1'
# dtk version # dtk version
......
...@@ -2174,7 +2174,6 @@ def gather_cache(src_cache: torch.Tensor, ...@@ -2174,7 +2174,6 @@ def gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16 #dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype) convert_fp8(dst, dst_fp8, scale, kv_dtype)
else: else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
......
...@@ -943,12 +943,13 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -943,12 +943,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window, decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes) self.kv_cache_dtype, self.alibi_slopes)
if use_custom: max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len) decode_meta.max_encoder_seq_len)
assert max_seq_len is not None assert max_seq_len is not None
if use_custom:
max_num_partitions = ( max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) // (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM) _PARTITION_SIZE_ROCM)
...@@ -1002,6 +1003,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -1002,6 +1003,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
if envs.VLLM_USE_FLASH_ATTN_PA: if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache from flash_attn import vllm_flash_attn_with_kvcache
if decode_meta.use_cuda_graph:
max_seq_len = 0
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:") print("PA SIZE:")
print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}") print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
...@@ -1024,6 +1027,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -1024,6 +1027,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
k_scale=layer._k_scale, k_scale=layer._k_scale,
v_scale=layer._v_scale, v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
max_seqlen_k=max_seq_len,
).squeeze(1) ).squeeze(1)
else: else:
out_pa[:] = paged_attn.forward_decode( out_pa[:] = paged_attn.forward_decode(
......
...@@ -30,6 +30,7 @@ try: ...@@ -30,6 +30,7 @@ try:
except AttributeError: except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment] tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -212,9 +213,9 @@ class Attention(nn.Module): ...@@ -212,9 +213,9 @@ class Attention(nn.Module):
# attn_metadata = get_forward_context().attn_metadata # attn_metadata = get_forward_context().attn_metadata
# #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)): # #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None: # if key is not None and value is not None:
# self.calc_kv_scales(query, key, value) # self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name) self.layer_name)
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
...@@ -439,6 +440,7 @@ direct_register_custom_op( ...@@ -439,6 +440,7 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,) tags=tag_cudagraph_unsafe,)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
......
...@@ -100,7 +100,7 @@ def flash_mla_with_kvcache( ...@@ -100,7 +100,7 @@ def flash_mla_with_kvcache(
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q, q,
k_cache, k_cache,
......
...@@ -326,7 +326,7 @@ class ModelConfig: ...@@ -326,7 +326,7 @@ class ModelConfig:
"""Whether to disable sliding window. If True, we will disable the sliding """Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the window functionality of the model, capping to sliding window size. If the
model does not support sliding window, this argument is ignored.""" model does not support sliding window, this argument is ignored."""
disable_cascade_attn: bool = False disable_cascade_attn: bool = True
"""Disable cascade attention for V1. While cascade attention does not """Disable cascade attention for V1. While cascade attention does not
change the mathematical correctness, disabling it could be useful for change the mathematical correctness, disabling it could be useful for
preventing potential numerical issues. Note that even if this is set to preventing potential numerical issues. Note that even if this is set to
...@@ -418,7 +418,6 @@ class ModelConfig: ...@@ -418,7 +418,6 @@ class ModelConfig:
- "transformers" will use the Transformers model implementation.""" - "transformers" will use the Transformers model implementation."""
override_attention_dtype: Optional[str] = None override_attention_dtype: Optional[str] = None
"""Override dtype for attention""" """Override dtype for attention"""
enable_chunked_prefill: Optional[bool] = None enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based """If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.""" on the remaining max_num_batched_tokens."""
......
...@@ -18,6 +18,7 @@ from vllm.forward_context import get_forward_context ...@@ -18,6 +18,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -213,27 +214,83 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -213,27 +214,83 @@ class P2pNcclConnector(KVConnectorBase_V1):
kv_cache_layer = kv_cache[ \ kv_cache_layer = kv_cache[ \
forward_context.virtual_engine] forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor( if not envs.VLLM_P2P_ASYNC:
request.request_id + "#" + layer_name) kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s", if kv_cache is None:
request.request_id) logger.warning("🚧src_kv_cache is None, %s",
continue request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id) inject_kv_into_layer(kv_cache_layer, kv_cache,
tensor_id = request.request_id + "#" + layer_name request.slot_mapping, request.request_id)
if tensor_id in self.p2p_nccl_engine.recv_store: tensor_id = request.request_id + "#" + layer_name
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None) if tensor_id in self.p2p_nccl_engine.recv_store:
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop( tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
request.request_id, None) self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop( request.request_id, None)
request.request_id, None) self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
addr = 0 request.request_id, None)
if isinstance(tensor, tuple): addr = 0
addr, _, _ = tensor if isinstance(tensor, tuple):
self.p2p_nccl_engine.pool.free(addr) addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
else:
dst_kv_cache_layer_shape = kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
num_pages * page_size, -1)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
2, num_pages * page_size, -1)
inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_token = kv_cache.shape[0]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
else:
num_token = kv_cache.shape[1]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[:, request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[:, request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
inject_start_index += num_token
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
...@@ -296,30 +353,29 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -296,30 +353,29 @@ class P2pNcclConnector(KVConnectorBase_V1):
request.slot_mapping_device = \ request.slot_mapping_device = \
request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True) request.slot_mapping.pin_memory().to(device=kv_layer.device, non_blocking=True)
slot_mapping = request.slot_mapping_device slot_mapping = request.slot_mapping_device
kv_cache = extract_kv_from_layer(kv_layer, slot_mapping)
tbo_evt = torch.cuda.Event(enable_timing=False) tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record() tbo_evt.record()
pp_rank = (self.parallel_config.rank // pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \ self.parallel_config.tensor_parallel_size) % \
self.parallel_config.pipeline_parallel_size self.parallel_config.pipeline_parallel_size
if (self.pp_size == 1): if (self.pp_size == 1):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt) (kv_layer, slot_mapping), remote_address, tbo_evt)
elif (self.pp_size == 2): elif (self.pp_size == 2):
if (pp_rank == 0): if (pp_rank == 0):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt) (kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank + 4), tbo_evt) (kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt)
else: else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address, tbo_evt) (kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + self._rank - 4), tbo_evt) (kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8): elif (self.pp_size == 8):
for i in range(8): for i in range(8):
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
kv_cache, ip + ":" + str(port + i), tbo_evt) (kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt)
else: else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") print("Error: only suppprt pp1 pp2 pp8!!!!!!")
else: else:
......
...@@ -21,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ...@@ -21,6 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool) TensorMemoryPool)
from vllm.utils import current_stream, get_ip from vllm.utils import current_stream, get_ip
from vllm import envs from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -61,7 +62,7 @@ def set_p2p_nccl_context(num_channels: str): ...@@ -61,7 +62,7 @@ def set_p2p_nccl_context(num_channels: str):
class P2pNcclEngine: class P2pNcclEngine:
def __init__(self, def __init__(self,
local_rank: int, local_rank: int,
config: KVTransferConfig, config: KVTransferConfig,
...@@ -111,8 +112,11 @@ class P2pNcclEngine: ...@@ -111,8 +112,11 @@ class P2pNcclEngine:
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
# self.send_stream = tbo_all_reduce_stream
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None
self.tensor_split_num: int = 0
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
...@@ -208,7 +212,54 @@ class P2pNcclEngine: ...@@ -208,7 +212,54 @@ class P2pNcclEngine:
return self._send_sync(tensor_id, tensor, remote_address) return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC": elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv: with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor, tbo_evt]) self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def p2p_async_send_tensor(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
else:
if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
kv_layer, slot_mapping = tensor # tesor (kv_layer, slot_mapping)
self.send_queue.append([tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
self.send_queue_cv.notify() self.send_queue_cv.notify()
else: # GET else: # GET
with self.send_store_cv: with self.send_store_cv:
...@@ -313,6 +364,8 @@ class P2pNcclEngine: ...@@ -313,6 +364,8 @@ class P2pNcclEngine:
self.zmq_address, remote_address.decode(), rank) self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
try: try:
with torch.cuda.stream(self.recv_stream): with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"], tensor = torch.empty(data["shape"],
...@@ -392,12 +445,17 @@ class P2pNcclEngine: ...@@ -392,12 +445,17 @@ class P2pNcclEngine:
with self.send_queue_cv: with self.send_queue_cv:
while not self.send_queue: while not self.send_queue:
self.send_queue_cv.wait() self.send_queue_cv.wait()
tensor_id, remote_address, tensor, tbo_evt = self.send_queue.popleft() if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt = self.send_queue.popleft()
else:
tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None: if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt) self.send_stream.wait_event(tbo_evt)
self._send_sync(tensor_id, tensor, remote_address) self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
self._send_sync(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
...@@ -409,6 +467,75 @@ class P2pNcclEngine: ...@@ -409,6 +467,75 @@ class P2pNcclEngine:
logger.debug( logger.debug(
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank) " to be empty, rank:%d", duration * 1000, self.rank)
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
slot_mapping: torch.Tensor, remote_address: str) -> bool:
if remote_address not in self.socks:
self._create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
is_mla = (kv_layer.ndim == 3)
hidden_dim = kv_layer.shape[-1]
if self.p2p_async_buf is None:
if is_mla:
self.p2p_async_buf = torch.empty((self.p2p_async_kv_tokens, hidden_dim),
dtype=kv_layer.dtype, device=kv_layer.device)
else:
self.p2p_async_buf = torch.empty((2, self.p2p_async_kv_tokens, hidden_dim),
dtype=kv_layer.dtype, device=kv_layer.device)
pack_num = (slot_mapping.shape[0] - 1) // self.p2p_async_kv_tokens + 1
self.tensor_split_num = pack_num
with torch.cuda.stream(self.send_stream):
for pack_idx in range(pack_num):
start = pack_idx * self.p2p_async_kv_tokens
end = min((pack_idx + 1) * self.p2p_async_kv_tokens, slot_mapping.shape[0])
sub_index = slot_mapping[start:end]
if is_mla:
num_pages, page_size = kv_layer.shape[0], kv_layer.shape[1]
data = kv_layer.reshape(num_pages * page_size, -1)
torch.index_select(data, dim=0, index=sub_index, out=self.p2p_async_buf[:end-start])
tx_shape = (end - start, hidden_dim)
else:
num_pages, page_size = kv_layer.shape[1], kv_layer.shape[2]
data = kv_layer.reshape(2, num_pages * page_size, -1)
torch.index_select(data, dim=1, index=sub_index, out=self.p2p_async_buf[:, :end-start])
tx_shape = (2, end - start, hidden_dim)
if is_mla:
send_tensor = self.p2p_async_buf[:end-start]
else:
send_tensor = self.p2p_async_buf[:, :end-start]
header = {
"cmd": "PUT",
"tensor_id": tensor_id + "#" + str(pack_idx), # 拼 pack_idx
"pack_idx": pack_idx,
"tensor_split_num": pack_num,
"shape": tx_shape,
"dtype": str(kv_layer.dtype).replace("torch.", "")
}
sock.send(msgpack.dumps(header))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
self.zmq_address, remote_address, rank,
tuple(send_tensor.shape), send_tensor.element_size() * send_tensor.numel() / 1024**3,
response.decode()
)
return False
self._send(comm, send_tensor, rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
return True
def _send_sync( def _send_sync(
self, self,
......
...@@ -1610,7 +1610,7 @@ class EngineArgs: ...@@ -1610,7 +1610,7 @@ class EngineArgs:
action = "Enabling" if \ action = "Enabling" if \
incremental_prefill_supported else "Disabling" incremental_prefill_supported else "Disabling"
if model_config.enable_chunked_prefill is not None and \ if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False: model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False self.enable_chunked_prefill = False
......
...@@ -166,11 +166,15 @@ if TYPE_CHECKING: ...@@ -166,11 +166,15 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHTOP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_MORI_EP: bool = False VLLM_USE_MORI_EP: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1104,6 +1108,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1104,6 +1108,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
("true", "1")),
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
("true", "1")),
# vLLM will use opt merge_aatn_states, not triton # vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
...@@ -1131,6 +1147,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1131,6 +1147,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))), lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -30,11 +30,11 @@ try: ...@@ -30,11 +30,11 @@ try:
except ImportError: except ImportError:
is_mori_available = False is_mori_available = False
logger = init_logger(__name__) logger = init_logger(__name__)
_MORI_OP = None _MORI_OP = None
@CustomOp.register("unquantized_ep_moe") @CustomOp.register("unquantized_ep_moe")
class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
"""MoE method without quantization.""" """MoE method without quantization."""
...@@ -44,20 +44,20 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -44,20 +44,20 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled()
def apply_ep( def apply_ep(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward( return self.forward(
...@@ -73,17 +73,17 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -73,17 +73,17 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# process MoE # process MoE
...@@ -109,48 +109,48 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod): ...@@ -109,48 +109,48 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
return output return output
def forward_cpu( def forward_cpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
**kwargs, **kwargs,
): ):
raise NotImplementedError raise NotImplementedError
def forward_hpu( def forward_hpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def forward_tpu( def forward_tpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace=True, inplace=True,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -167,49 +167,50 @@ class EPMoE(FusedMoE): ...@@ -167,49 +167,50 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl dp+ep MoE Expert Parallel Impl
""" """
def __init__( def __init__(
self, self,
num_experts: int, # Global number of experts num_experts: int, # Global number of experts
top_k: int, top_k: int,
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False, reduce_results: bool = False,
renormalize: bool = True, renormalize: bool = True,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
ep_size: Optional[int] = None, ep_size: Optional[int] = None,
dp_size: Optional[int] = None, dp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
moe_permute_fusion: bool = False, moe_permute_fusion: bool = False,
moe_shared_expert_overlap: bool = False moe_shared_expert_overlap: bool = False
): ):
super().__init__(num_experts, top_k, hidden_size, super().__init__(num_experts, top_k, hidden_size,
intermediate_size, params_dtype, intermediate_size, params_dtype,
reduce_results, renormalize, reduce_results, renormalize,
use_grouped_topk, num_expert_group, use_grouped_topk, num_expert_group,
topk_group, quant_config, tp_size, topk_group, quant_config, tp_size,
ep_size, dp_size, prefix, ep_size, dp_size, prefix,
custom_routing_function, scoring_func, custom_routing_function, scoring_func,
e_score_correction_bias, e_score_correction_bias,
apply_router_weight_on_input, apply_router_weight_on_input,
activation, activation,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
num_redundant_experts=num_redundant_experts, num_redundant_experts=num_redundant_experts,
) )
self.ep_moe_config: EpMoeConfig = EpMoeConfig.make( self.ep_moe_config: EpMoeConfig = EpMoeConfig.make(
moe_router_topk=self.top_k, moe_router_topk=self.top_k,
# TODO: support fusion permute # TODO: support fusion permute
...@@ -222,31 +223,42 @@ class EPMoE(FusedMoE): ...@@ -222,31 +223,42 @@ class EPMoE(FusedMoE):
) )
local_expert_indices_offset = ( local_expert_indices_offset = (
self.ep_rank * self.local_num_experts self.ep_rank * self.local_num_experts
) )
self.local_expert_indices = [ self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.local_num_experts) local_expert_indices_offset + i for i in range(self.local_num_experts)
] ]
self.use_shared_expert = False self.use_shared_expert = False
<<<<<<< HEAD
# self.token_dispatcher = MoEAlltoAllTokenDispatcher( # self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices, # self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher", # config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# ) # )
=======
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.local_num_experts, self.local_expert_indices,
config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
)
>>>>>>> origin/v0.9.2-dev-ds
self.shared_expert_overlap = moe_shared_expert_overlap self.shared_expert_overlap = moe_shared_expert_overlap
self.shared_experts = None self.shared_experts = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
self.scales = None self.scales = None
self.use_int8_dispatch = True self.use_int8_dispatch = True
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs self.max_num_inp_token_per_rank = vllm_config.scheduler_config.max_num_seqs
self.mori_op = self.get_mori_op() self.mori_op = self.get_mori_op()
<<<<<<< HEAD
=======
self.first = True
>>>>>>> origin/v0.9.2-dev-ds
def get_mori_op(self): def get_mori_op(self):
global _MORI_OP global _MORI_OP
if _MORI_OP is None: if _MORI_OP is None:
...@@ -258,14 +270,14 @@ class EPMoE(FusedMoE): ...@@ -258,14 +270,14 @@ class EPMoE(FusedMoE):
# assert world_group is not None # assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group) # torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default") # mori.shmem.shmem_torch_process_group_init("default")
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
multi_node = self.ep_size / 8 > 1 multi_node = self.ep_size / 8 > 1
mori_data_type=vllm_config.model_config.dtype mori_data_type = vllm_config.model_config.dtype
mori_scale_type_size = vllm_config.model_config.dtype.itemsize mori_scale_type_size = vllm_config.model_config.dtype.itemsize
if self.use_int8_dispatch: if self.use_int8_dispatch:
mori_scale_type_size = 4 mori_scale_type_size = 4
config = mori.ops.EpDispatchCombineConfig( config = mori.ops.EpDispatchCombineConfig(
data_type=mori_data_type, data_type=mori_data_type,
...@@ -280,12 +292,12 @@ class EPMoE(FusedMoE): ...@@ -280,12 +292,12 @@ class EPMoE(FusedMoE):
max_token_type_size=2, max_token_type_size=2,
block_num=80, block_num=80,
warp_num_per_block=16, warp_num_per_block=16,
#kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode # kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \ kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode if multi_node else \
mori.ops.EpDispatchCombineKernelType.IntraNode mori.ops.EpDispatchCombineKernelType.IntraNode
) )
_MORI_OP = mori.ops.EpDispatchCombineOp(config) _MORI_OP = mori.ops.EpDispatchCombineOp(config)
return _MORI_OP return _MORI_OP
def set_shared_experts(self, shared_experts: torch.nn.Module): def set_shared_experts(self, shared_experts: torch.nn.Module):
...@@ -305,15 +317,19 @@ class EPMoE(FusedMoE): ...@@ -305,15 +317,19 @@ class EPMoE(FusedMoE):
assert quant_method is not None assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method return quant_method
def sync(self): def sync(self):
<<<<<<< HEAD
torch.cuda.synchronize() torch.cuda.synchronize()
=======
# torch.cuda.synchronize()
>>>>>>> origin/v0.9.2-dev-ds
dist.barrier() dist.barrier()
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits, return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name) self.layer_name)
def get_expert_weights(self) -> Iterable[torch.Tensor]: def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters()) weights = list(self.named_parameters())
...@@ -332,30 +348,29 @@ class EPMoE(FusedMoE): ...@@ -332,30 +348,29 @@ class EPMoE(FusedMoE):
return [ return [
weight.view(self.local_num_experts, -1) for name, weight in weights weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS if name not in NON_EXPERT_WEIGHTS
] ]
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor):
topk_weights, topk_ids = self.select_experts( topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
indices_type=torch.int32, indices_type=torch.int32,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate) use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
if self.use_int8_dispatch: if self.use_int8_dispatch:
hidden_states, scales = per_token_quant_int8(hidden_states) hidden_states, scales = per_token_quant_int8(hidden_states)
else: else:
...@@ -368,23 +383,48 @@ class EPMoE(FusedMoE): ...@@ -368,23 +383,48 @@ class EPMoE(FusedMoE):
) )
scales = self.scales scales = self.scales
# self.sync()
#self.sync()
( (
dispatch_output, dispatch_output,
dispatch_weights, dispatch_weights,
dispatch_scales, dispatch_scales,
dispatch_indices, dispatch_indices,
dispatch_recv_num_token, dispatch_recv_num_token,
) = self.mori_op.dispatch( ) = self.mori_op.dispatch(
hidden_states, hidden_states,
topk_weights, topk_weights,
scales, scales,
topk_ids, topk_ids,
layer_idx=int(self.layer_name.split('.')[2]) #layer_idx=int(self.layer_name.split('.')[2])
) )
<<<<<<< HEAD
#self.sync() #self.sync()
=======
# self.sync()
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output_clip,
# topk_weights=dispatch_weights_clip,
# topk_ids=dispatch_indices_clip,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales_clip if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# )
>>>>>>> origin/v0.9.2-dev-ds
expert_output = self.quant_method.apply_ep( expert_output = self.quant_method.apply_ep(
layer=self, layer=self,
...@@ -399,32 +439,33 @@ class EPMoE(FusedMoE): ...@@ -399,32 +439,33 @@ class EPMoE(FusedMoE):
num_local_tokens=dispatch_recv_num_token, num_local_tokens=dispatch_recv_num_token,
config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size, config_select_bs=hidden_states.shape[0]*self.ep_size/self.dp_size,
scales=dispatch_scales if self.use_int8_dispatch else None scales=dispatch_scales if self.use_int8_dispatch else None
#routed_scaling_factor=self.routed_scaling_factor, # routed_scaling_factor=self.routed_scaling_factor,
) )
#self.sync() # self.sync()
combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids) combine_output, _ = self.mori_op.combine(expert_output, dispatch_weights, topk_ids)
final_hidden_states = combine_output[:hidden_states.shape[0], :] final_hidden_states = combine_output[:hidden_states.shape[0], :]
#self.sync() # self.sync()
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None: if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
# shared_output = ( # shared_output = (
# self.maybe_all_reduce_tensor_model_parallel( # self.maybe_all_reduce_tensor_model_parallel(
# shared_output)) # shared_output))
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
return final_hidden_states return final_hidden_states
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
...@@ -433,7 +474,7 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, ...@@ -433,7 +474,7 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -443,5 +484,5 @@ direct_register_custom_op( ...@@ -443,5 +484,5 @@ direct_register_custom_op(
mutates_args=["hidden_states", "router_logits"], mutates_args=["hidden_states", "router_logits"],
fake_impl=ep_moe_forward_fake, fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -181,7 +181,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -181,7 +181,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self.use_all_gather = current_platform.use_all_gather() self.use_all_gather = current_platform.use_all_gather()
self.probs = None self.probs = None
self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
...@@ -446,7 +445,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher): ...@@ -446,7 +445,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
if self.config.moe_shared_expert_overlap and self.shared_experts is not None: if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
shared_output = self.shared_experts.get_output() shared_output = self.shared_experts.get_output()
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16:
output = output + shared_output output = output + shared_output
else: else:
# Fix FP16 overflow # Fix FP16 overflow
......
...@@ -45,14 +45,141 @@ from lightop import op ...@@ -45,14 +45,141 @@ from lightop import op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce_triton(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
if token_num <= 32:
BLOCK_M = 1
BLOCK_DIM = 512
NUM_STAGE = 2
num_warps = 4
elif token_num <= 128:
BLOCK_M = 1
BLOCK_DIM = 1024
NUM_STAGE = 0
num_warps = 2
elif token_num <= 4096:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 0
num_warps = 2
else:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 2
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def moe_reduce_dispatch(
intermediate_cache3: torch.Tensor,
out_hidden_states: torch.Tensor,
begin_chunk_idx: int,
end_chunk_idx: int,
):
inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
n = intermediate_cache3.shape[0]
# 根据 n 大小选择不同的 reduce 实现
if 1 <= n <= 4:
moe_sum_reduce_torch_compile(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 4 < n <= 1024:
moe_sum_reduce_triton(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 1024 < n <= 32768:
ops.moe_sum_opt1(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
else:
ops.moe_sum(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
def get_moe_cache(top_k_num,N,K,device,dtype): def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton global moe_cache_singleton
if moe_cache_singleton is None: if moe_cache_singleton is None:
...@@ -1266,7 +1393,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1266,7 +1393,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None: routed_scaling_factor: Optional[float] = 1.0) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
...@@ -1300,7 +1427,7 @@ def inplace_fused_experts_fake( ...@@ -1300,7 +1427,7 @@ def inplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None: routed_scaling_factor: Optional[float] = 1.0) -> None:
pass pass
...@@ -1338,7 +1465,7 @@ def outplace_fused_experts( ...@@ -1338,7 +1465,7 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
...@@ -1372,7 +1499,7 @@ def outplace_fused_experts_fake( ...@@ -1372,7 +1499,7 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1431,7 +1558,7 @@ def fused_experts( ...@@ -1431,7 +1558,7 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1520,7 +1647,7 @@ def fused_experts_impl( ...@@ -1520,7 +1647,7 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1769,7 +1896,7 @@ def fused_experts_impl( ...@@ -1769,7 +1896,7 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick: if envs.VLLM_USE_LIGHTOP:
from lightop import op as op from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx], output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
...@@ -1789,8 +1916,16 @@ def fused_experts_impl( ...@@ -1789,8 +1916,16 @@ def fused_experts_impl(
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), # ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor # out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else: else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), if envs.VLLM_USE_LIGHTOP_MOE_SUM:
out_hidden_states[begin_chunk_idx:end_chunk_idx]) from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
...@@ -1825,7 +1960,7 @@ def fused_moe( ...@@ -1825,7 +1960,7 @@ def fused_moe(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
......
...@@ -376,7 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -376,7 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
...@@ -423,7 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -423,7 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation: str = "silu", activation: str = "silu",
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -487,7 +487,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -487,7 +487,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
**kwargs, **kwargs,
): ):
...@@ -527,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -527,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
...@@ -683,7 +683,7 @@ class FusedMoE(torch.nn.Module): ...@@ -683,7 +683,7 @@ class FusedMoE(torch.nn.Module):
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
): ):
super().__init__() super().__init__()
if params_dtype is None: if params_dtype is None:
...@@ -1269,7 +1269,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1269,7 +1269,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = 1.0,
use_fused_gate: Optional[bool] = False use_fused_gate: Optional[bool] = False
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
......
...@@ -240,8 +240,16 @@ def moe_align_block_size( ...@@ -240,8 +240,16 @@ def moe_align_block_size(
expert_mask = expert_mask, expert_mask = expert_mask,
num_local_tokens = None) num_local_tokens = None)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
expert_ids, num_tokens_post_pad) from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad,
expert_map = None,
expert_mask = None,
num_local_tokens = None)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None: if expert_map is not None:
expert_ids = expert_map[expert_ids] expert_ids = expert_map[expert_ids]
......
...@@ -514,6 +514,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -514,6 +514,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -473,6 +473,7 @@ class BlockInt8MoEMethod: ...@@ -473,6 +473,7 @@ class BlockInt8MoEMethod:
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -514,5 +515,7 @@ class BlockInt8MoEMethod: ...@@ -514,5 +515,7 @@ class BlockInt8MoEMethod:
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
use_nn_moe=use_nn_moe use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
\ No newline at end of file
...@@ -348,6 +348,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -348,6 +348,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -430,7 +431,9 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -430,7 +431,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
w1_zp=layer.w13_qzeros if has_zp else None, w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None, w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size], block_shape=[0, layer.group_size],
use_nn_moe=False) use_nn_moe=False,
shared_output=shared_output,
)
@staticmethod @staticmethod
def get_weight_loader(layer, weight_loader): def get_weight_loader(layer, weight_loader):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment