Unverified Commit 1248e850 authored by Wenxiang's avatar Wenxiang Committed by GitHub
Browse files

[Model] Adding support for MSFT Phi-3.5-MoE (#7729)


Co-authored-by: default avatarYour Name <you@example.com>
Co-authored-by: default avatarZeqi Lin <zelin@microsoft.com>
Co-authored-by: default avatarZeqi Lin <Zeqi.Lin@microsoft.com>
parent 2684efc4
...@@ -147,6 +147,10 @@ Decoder-only Language Models ...@@ -147,6 +147,10 @@ Decoder-only Language Models
- Phi-3-Small - Phi-3-Small
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
- -
* - :code:`PhiMoEForCausalLM`
- Phi-3.5-MoE
- :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
-
* - :code:`PersimmonForCausalLM` * - :code:`PersimmonForCausalLM`
- Persimmon - Persimmon
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
......
"""Compare the outputs of HF and vLLM for moe models using greedy sampling.
Run `pytest tests/models/test_phimoe.py`.
"""
import pytest
import torch
from vllm.utils import is_cpu
from .utils import check_logprobs_close
MODELS = [
"microsoft/Phi-3.5-MoE-instruct",
]
def test_phimoe_routing_function():
from vllm.model_executor.models.phimoe import phimoe_routing_function
test_case = {
0: {
"hidden_states":
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.float32,
requires_grad=False).view(4, 2),
"gating_output":
torch.tensor([0.1, 0.2, 0.3, 0.4],
dtype=torch.float32,
requires_grad=False),
"topk":
2,
"renormalize":
False,
},
1: {
"hidden_states":
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.float32,
requires_grad=False).view(4, 2),
"gating_output":
torch.tensor([0.4, 0.2, 0.3, 0.4],
dtype=torch.float32,
requires_grad=False),
"topk":
2,
"renormalize":
False,
}
}
ground_truth = {
0: {
"topk_weights":
torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False),
"topk_ids":
torch.tensor([3, 2], dtype=torch.long, requires_grad=False),
},
1: {
"topk_weights":
torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False),
"topk_ids":
torch.tensor([0, 3], dtype=torch.long, requires_grad=False),
}
}
for test_id in test_case:
topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id])
assert torch.allclose(topk_weights,
ground_truth[test_id]["topk_weights"])
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
def get_gpu_memory():
try:
props = torch.cuda.get_device_properties(torch.cuda.current_device())
gpu_memory = props.total_memory / (1024**3)
return gpu_memory
except Exception:
return 0
@pytest.mark.skipif(condition=is_cpu(),
reason="This test takes a lot time to run on CPU, "
"and vllm CI's disk space is not enough for this model.")
@pytest.mark.skipif(condition=get_gpu_memory() < 100,
reason="Skip this test if GPU memory is insufficient.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
{
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}
\ No newline at end of file
{
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
\ No newline at end of file
{
"2048": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3584": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
\ No newline at end of file
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
import torch import torch
import triton import triton
...@@ -446,7 +446,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -446,7 +446,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
rand_perm1: torch.Tensor, rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor, rand_perm2: torch.Tensor,
topk: int, topk: int,
renormalize: bool, custom_routing_function: Optional[Callable] = None,
renormalize: bool = True,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False, use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -497,8 +498,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor, ...@@ -497,8 +498,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
E = w1.shape[0] E = w1.shape[0]
N = w2.shape[1] * 16 N = w2.shape[1] * 16
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, if custom_routing_function is None:
renormalize) topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
get_config_func = functools.partial(try_get_optimal_moe_config, get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape, w1.shape,
...@@ -695,6 +700,7 @@ def fused_moe( ...@@ -695,6 +700,7 @@ def fused_moe(
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -742,9 +748,12 @@ def fused_moe( ...@@ -742,9 +748,12 @@ def fused_moe(
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize, topk, renormalize,
num_expert_group, topk_group) num_expert_group, topk_group)
else: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize) renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
return fused_experts(hidden_states, return fused_experts(hidden_states,
w1, w1,
......
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
...@@ -62,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -62,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool, top_k: int,
use_grouped_topk: bool, renormalize: bool,
topk_group: Optional[int] = None, use_grouped_topk: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -79,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -79,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
def forward_cuda(self,
layer: torch.nn.Module, def forward_cuda(
x: torch.Tensor, self,
use_grouped_topk: bool, layer: torch.nn.Module,
top_k: int, x: torch.Tensor,
router_logits: torch.Tensor, use_grouped_topk: bool,
renormalize: bool, top_k: int,
topk_group: Optional[int] = None, router_logits: torch.Tensor,
num_expert_group: Optional[int] = None) -> torch.Tensor: renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts) fused_experts)
...@@ -101,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -101,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -114,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -114,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError( raise NotImplementedError(
"The CPU backend currently does not support MoE.") "The CPU backend currently does not support MoE.")
def forward_tpu(self, def forward_tpu(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
use_grouped_topk: bool, x: torch.Tensor,
top_k: int, use_grouped_topk: bool,
router_logits: torch.Tensor, top_k: int,
renormalize: bool, router_logits: torch.Tensor,
topk_group: Optional[int] = None, renormalize: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x, return fused_moe(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -172,6 +184,7 @@ class FusedMoE(torch.nn.Module): ...@@ -172,6 +184,7 @@ class FusedMoE(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None,
): ):
super().__init__() super().__init__()
...@@ -190,6 +203,7 @@ class FusedMoE(torch.nn.Module): ...@@ -190,6 +203,7 @@ class FusedMoE(torch.nn.Module):
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group self.num_expert_group = num_expert_group
self.topk_group = topk_group self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
...@@ -390,7 +404,8 @@ class FusedMoE(torch.nn.Module): ...@@ -390,7 +404,8 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk: bool, use_grouped_topk: bool,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None): num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk)
...@@ -405,11 +420,17 @@ class FusedMoE(torch.nn.Module): ...@@ -405,11 +420,17 @@ class FusedMoE(torch.nn.Module):
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group) topk_group=topk_group)
else: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize) renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -426,7 +447,8 @@ class FusedMoE(torch.nn.Module): ...@@ -426,7 +447,8 @@ class FusedMoE(torch.nn.Module):
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group) num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
import enum import enum
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Callable, List, Optional
import torch import torch
...@@ -256,15 +256,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -256,15 +256,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
) )
replace_tensor("w2_weight_scale", marlin_w2_scales) replace_tensor("w2_weight_scale", marlin_w2_scales)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool = True, top_k: int,
use_grouped_topk: bool = False, renormalize: bool = True,
num_expert_group: Optional[int] = None, use_grouped_topk: bool = False,
topk_group: Optional[int] = None) -> torch.Tensor: num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_marlin_moe) fused_marlin_moe)
...@@ -278,6 +281,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -278,6 +281,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
layer.w13_g_idx_sort_indices, layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
top_k, top_k,
custom_routing_function,
renormalize=renormalize, renormalize=renormalize,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale) w2_scale=layer.w2_weight_scale)
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -96,15 +96,18 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -96,15 +96,18 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.register_parameter("w2_scale", w2_scale) layer.register_parameter("w2_scale", w2_scale)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool = True, top_k: int,
use_grouped_topk: bool = False, renormalize: bool = True,
num_expert_group: Optional[int] = None, use_grouped_topk: bool = False,
topk_group: Optional[int] = None) -> torch.Tensor: num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
...@@ -114,7 +117,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -114,7 +117,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,
......
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -468,15 +468,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -468,15 +468,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
return return
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool, top_k: int,
use_grouped_topk: bool, renormalize: bool,
topk_group: Optional[int] = None, use_grouped_topk: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -487,7 +490,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -487,7 +490,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(x, return fused_experts(x,
layer.w13_weight, layer.w13_weight,
......
...@@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -503,8 +503,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
dtype: torch.dtype, dtype: torch.dtype,
short_factor: List[float], short_factor: List[float],
long_factor: List[float], long_factor: List[float],
short_mscale: float = 1.0, short_mscale: Optional[float] = None,
long_mscale: float = 1.0, long_mscale: Optional[float] = None,
): ):
super().__init__() super().__init__()
...@@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -523,18 +523,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self.base = base self.base = base
self.short_factor = short_factor self.short_factor = short_factor
self.long_factor = long_factor self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
scale = (self.max_position_embeddings /
self.original_max_position_embeddings)
scale = self.max_position_embeddings / \
self.original_max_position_embeddings
if scale <= 1.0: if scale <= 1.0:
self.scaling_factor = 1.0 scaling_factor = 1.0
else: else:
self.scaling_factor = math.sqrt( scaling_factor = math.sqrt(
1 + math.log(scale) / 1 + math.log(scale) /
math.log(self.original_max_position_embeddings)) math.log(self.original_max_position_embeddings))
if short_mscale is None:
short_mscale = scaling_factor
if long_mscale is None:
long_mscale = scaling_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache( short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale) original_max_position_embeddings, short_factor, short_mscale)
...@@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -571,8 +575,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
inv_freq = self._compute_inv_freq(rescale_factors) inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float) t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale * self.scaling_factor cos = freqs.cos() * mscale
sin = freqs.sin() * mscale * self.scaling_factor sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
......
...@@ -50,6 +50,7 @@ _GENERATION_MODELS = { ...@@ -50,6 +50,7 @@ _GENERATION_MODELS = {
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment