Commit f9f1887d authored by 王敏's avatar 王敏
Browse files

merge 092-dev分支近期修改

parent 415b817b
......@@ -164,6 +164,10 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
VLLM_USE_ALLTOALL_EP: bool = False
def get_default_cache_root():
......@@ -1090,6 +1094,23 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
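# vLLM will use light-op kernels (NOTE(review): the previous comment was
# copied from VLLM_USE_GLOBAL_CACHE13 -- confirm the exact scope of this flag)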
"VLLM_USE_LIGHT_OP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "True").lower() in
("true", "1")),
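# vLLM will use the Triton cat kernel (NOTE(review): the previous comment was
# copied from VLLM_USE_GLOBAL_CACHE13 -- confirm the exact scope of this flag)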
"VLLM_USE_TRITON_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in
("true", "1")),
# vLLM will use the optimized merge_attn_states implementation, not the Triton one
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")),
# vLLM will use the fused rms-quant op
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# vLLM will use all_to_all ep mode
"VLLM_USE_ALLTOALL_EP":
lambda: (os.environ.get("VLLM_USE_ALLTOALL_EP", "True").lower() in
......
......@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.distributed.parallel_state import get_ep_group, get_node_count, is_use_cuda_graph
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -362,10 +362,10 @@ class EPMoE(FusedMoE):
#dispatch_recv_num_token = dispatch_recv_num_token[0].item()
dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
dispatch_output = dispatch_output[:dispatch_recv_num_token]
dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# dispatch_output = dispatch_output[:dispatch_recv_num_token]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
......
......@@ -216,6 +216,42 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
w2_marlin_list.append(w2_marlin_in)
layer.w2_weight = Parameter(torch.stack(w2_marlin_list, dim=0), requires_grad=False)
def apply_ep(  # dp+ep
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    use_nn_moe: Optional[bool] = False,
    **_
) -> torch.Tensor:
    """Run the W4A8 Marlin fused-experts implementation for the dp+ep path.

    Fetches the workspace and global-reduce buffers for ``x.device`` from
    ``MarlinMoeWorkspace`` and forwards everything to
    ``fused_experts_impl_w4a8_marlin`` with ``inplace=True``,
    ``use_int4_w4a8=True`` and ``per_channel_quant=True``.

    Args:
        layer: Module providing ``w13_weight``, ``w2_weight``,
            ``w13_weight_scale``, ``w2_weight_scale``, ``w13_input_scale``
            and ``w2_input_scale``.
        x: Input activations; its ``device`` selects the workspace buffers.
        topk_weights: Forwarded unchanged as ``topk_weights``.
        topk_ids: Forwarded unchanged as ``topk_ids``.
        global_num_experts: Forwarded unchanged (defaults to ``-1``).
        expert_map: Forwarded unchanged (may be ``None``).
        apply_router_weight_on_input: Forwarded unchanged.
        activation: Forwarded unchanged (defaults to ``"silu"``).
        use_nn_moe: Forwarded unchanged.
        **_: Any additional keyword arguments are accepted and ignored.

    Returns:
        Whatever ``fused_experts_impl_w4a8_marlin`` returns.
    """
    # NOTE(review): inplace=True is passed below; presumably the kernel may
    # write its result into ``x`` -- confirm against the kernel implementation.
    marlin_workspace, reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()

    # Packed expert weights, passed positionally after the activations.
    gate_up_weight = layer.w13_weight
    down_weight = layer.w2_weight

    # Weight scales and activation (input) scales for both projections.
    gate_up_weight_scale = layer.w13_weight_scale
    down_weight_scale = layer.w2_weight_scale
    gate_up_input_scale = layer.w13_input_scale
    down_input_scale = layer.w2_input_scale

    return fused_experts_impl_w4a8_marlin(
        x,
        gate_up_weight,
        down_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        workspace=marlin_workspace,
        global_reduce_buffer=reduce_buffer,
        inplace=True,
        use_int4_w4a8=True,
        per_channel_quant=True,
        activation=activation,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        w1_scale=gate_up_weight_scale,
        w2_scale=down_weight_scale,
        a1_scale=gate_up_input_scale,
        a2_scale=down_input_scale,
        use_nn_moe=use_nn_moe,
    )
def apply(
self,
layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment