Unverified Commit 07c43530 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Model] Support Grok1 (#13795)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 34e3494e
...@@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `parasail-ai/GritLM-7B-vllm`. * `parasail-ai/GritLM-7B-vllm`.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `Grok1ModelForCausalLM`
* Grok1
* `hpcai-tech/grok-1`.
* ✅︎
* ✅︎
- * `InternLMForCausalLM` - * `InternLMForCausalLM`
* InternLM * InternLM
* `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. * `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.
......
...@@ -130,6 +130,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -130,6 +130,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
"GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
"GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
"Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
trust_remote_code=True),
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
trust_remote_code=True), trust_remote_code=True),
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
......
...@@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, activation, use_fp8_w8a8, use_int8_w8a16,
global_num_experts, expert_map, w1_scale, w2_scale, use_int4_w4a16, global_num_experts, expert_map,
w1_zp, w2_zp, a1_scale, a2_scale, block_shape) w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake( ...@@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake(
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1093,6 +1096,7 @@ def outplace_fused_experts( ...@@ -1093,6 +1096,7 @@ def outplace_fused_experts(
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1106,7 +1110,7 @@ def outplace_fused_experts( ...@@ -1106,7 +1110,7 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16, False, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape) a2_scale, block_shape)
...@@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake( ...@@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake(
w2: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor,
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts( torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
...@@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape) block_shape=block_shape)
torch.ops._C.silu_and_mul(intermediate_cache2, if activation == "silu":
intermediate_cache1.view(-1, N)) torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(intermediate_cache2,
w2, w2,
...@@ -1339,6 +1354,7 @@ def fused_moe( ...@@ -1339,6 +1354,7 @@ def fused_moe(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
inplace: bool = False, inplace: bool = False,
activation: str = "silu",
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
...@@ -1370,6 +1386,8 @@ def fused_moe( ...@@ -1370,6 +1386,8 @@ def fused_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. - inplace (bool): If True, perform the operation in-place.
Defaults to False. Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- num_expert_group: Optional[int]: additional parameter for grouped_topk - num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk
...@@ -1420,6 +1438,7 @@ def fused_moe( ...@@ -1420,6 +1438,7 @@ def fused_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
inplace=inplace, inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
......
...@@ -120,7 +120,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -120,7 +120,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -134,7 +135,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -134,7 +135,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map, expert_map=expert_map,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
activation=activation)
def forward_cuda( def forward_cuda(
self, self,
...@@ -150,7 +152,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -150,7 +152,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -170,6 +173,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -170,6 +173,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
...@@ -186,9 +190,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -186,9 +190,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
**kwargs, **kwargs,
): ):
assert custom_routing_function is None assert custom_routing_function is None
assert activation == "silu", f"{activation} is not supported."
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, use_grouped_topk,
...@@ -213,7 +219,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -213,7 +219,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
...@@ -225,6 +232,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -225,6 +232,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
raise NotImplementedError( raise NotImplementedError(
"Expert score correction bias is not supported for TPU.") "Expert score correction bias is not supported for TPU.")
assert activation == "silu", f"{activation} is not supported for TPU."
return fused_moe_pallas(hidden_states=x, return fused_moe_pallas(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -277,6 +285,7 @@ class FusedMoE(torch.nn.Module): ...@@ -277,6 +285,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
): ):
super().__init__() super().__init__()
...@@ -305,6 +314,7 @@ class FusedMoE(torch.nn.Module): ...@@ -305,6 +314,7 @@ class FusedMoE(torch.nn.Module):
self.custom_routing_function = custom_routing_function self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.activation = activation
self.expert_map = None self.expert_map = None
if self.ep_size > 1: if self.ep_size > 1:
...@@ -653,7 +663,9 @@ class FusedMoE(torch.nn.Module): ...@@ -653,7 +663,9 @@ class FusedMoE(torch.nn.Module):
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias) e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.) # Default set to False. (May have to add shared expert outputs.)
......
...@@ -469,7 +469,9 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -469,7 +469,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None: if expert_map is not None:
raise NotImplementedError( raise NotImplementedError(
"Expert Parallelism is not supported for " "Expert Parallelism is not supported for "
......
...@@ -219,6 +219,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -219,6 +219,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -240,6 +241,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -240,6 +241,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -550,7 +552,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -550,7 +552,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None: if expert_map is not None:
raise NotImplementedError( raise NotImplementedError(
"Expert Parallelism is not supported for " "Expert Parallelism is not supported for "
......
...@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -134,6 +135,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -134,6 +135,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation,
use_int8_w8a16=True, use_int8_w8a16=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
......
...@@ -675,6 +675,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -675,6 +675,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -698,6 +699,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -698,6 +699,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
......
...@@ -590,7 +590,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -590,7 +590,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.half() x = x.half()
......
This diff is collapsed.
...@@ -60,6 +60,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -60,6 +60,7 @@ _TEXT_GENERATION_MODELS = {
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment