Unverified commit 55dcce91, authored by Lu Fang and committed by GitHub
Browse files

Upstream Llama4 Support to Main (#16113)


Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8017c8db
...@@ -23,6 +23,7 @@ def cutlass_moe_fp8( ...@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half, out_dtype: torch.dtype = torch.half,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a a8w8-quantized Mixture of Experts (MoE) layer This function computes a a8w8-quantized Mixture of Experts (MoE) layer
...@@ -96,8 +97,14 @@ def cutlass_moe_fp8( ...@@ -96,8 +97,14 @@ def cutlass_moe_fp8(
n = w2_q.size(1) n = w2_q.size(1)
topk = topk_ids.size(1) topk = topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False) a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant( a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token) a, a1_scale, use_per_token_if_dynamic=per_act_token)
...@@ -139,6 +146,8 @@ def cutlass_moe_fp8( ...@@ -139,6 +146,8 @@ def cutlass_moe_fp8(
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2, expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2) ab_strides2, c_strides2)
# Gather tokens
return (c2[c_map].view(m, topk, k) * c2 = c2[c_map].view(m, topk, k)
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
...@@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a16, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int4_w4a16, global_num_experts, expert_map, use_int8_w8a16, use_int4_w4a16, global_num_experts,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
block_shape) a2_scale, block_shape)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -980,6 +981,7 @@ def inplace_fused_experts_fake( ...@@ -980,6 +981,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1010,6 +1012,7 @@ def outplace_fused_experts( ...@@ -1010,6 +1012,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1023,10 +1026,11 @@ def outplace_fused_experts( ...@@ -1023,10 +1026,11 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8, use_int8_w8a16, False, activation, apply_router_weight_on_input,
use_int4_w4a16, global_num_experts, expert_map, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, global_num_experts, expert_map, w1_scale,
a2_scale, block_shape) w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor,
allow_deep_gemm: bool = False) -> torch.Tensor: allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8 if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8( return deep_gemm_moe_fp8(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
...@@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
...@@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
...@@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
False, apply_router_weight_on_input,
top_k_num, top_k_num,
config, config,
compute_type=compute_type, compute_type=compute_type,
...@@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
True, not apply_router_weight_on_input,
1, 1,
config, config,
compute_type=compute_type, compute_type=compute_type,
......
...@@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -156,9 +158,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -156,9 +158,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(
x=x,
layer=layer, layer=layer,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k, top_k=top_k,
...@@ -171,7 +175,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -171,7 +175,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation) activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def forward_cuda( def forward_cuda(
self, self,
...@@ -188,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -188,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
...@@ -202,13 +208,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -202,13 +208,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts(hidden_states=x, return fused_experts(
hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map)
...@@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False,
**kwargs, **kwargs,
): ):
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
assert apply_router_weight_on_input is False
return layer.ipex_fusion( return layer.ipex_fusion(
x, x,
use_grouped_topk, use_grouped_topk,
...@@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
...@@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert topk_group is None assert topk_group is None
assert custom_routing_function is None assert custom_routing_function is None
assert layer is not None assert layer is not None
assert apply_router_weight_on_input is False
if scoring_func != "softmax": if scoring_func != "softmax":
raise NotImplementedError( raise NotImplementedError(
"Only softmax scoring function is supported for HPU.") "Only softmax scoring function is supported for HPU.")
...@@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
assert custom_routing_function is None assert custom_routing_function is None
assert apply_router_weight_on_input is False
if scoring_func != "softmax": if scoring_func != "softmax":
raise NotImplementedError( raise NotImplementedError(
"Only softmax scoring function is supported for TPU.") "Only softmax scoring function is supported for TPU.")
...@@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module): ...@@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
): ):
super().__init__() super().__init__()
...@@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module): ...@@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix) self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None assert self.quant_method is not None
self.apply_router_weight_on_input = apply_router_weight_on_input
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
"hidden_size": hidden_size, "hidden_size": hidden_size,
...@@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module): ...@@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module):
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
) )
if self.dp_size > 1: if self.dp_size > 1:
......
...@@ -92,6 +92,7 @@ class RMSNorm(CustomOp): ...@@ -92,6 +92,7 @@ class RMSNorm(CustomOp):
eps: float = 1e-6, eps: float = 1e-6,
var_hidden_size: Optional[int] = None, var_hidden_size: Optional[int] = None,
has_weight: bool = True, has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -100,7 +101,9 @@ class RMSNorm(CustomOp): ...@@ -100,7 +101,9 @@ class RMSNorm(CustomOp):
self.variance_size_override = (None if var_hidden_size == hidden_size self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size) else var_hidden_size)
self.has_weight = has_weight self.has_weight = has_weight
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size) self.weight = torch.ones(hidden_size)
if self.has_weight: if self.has_weight:
self.weight = nn.Parameter(self.weight) self.weight = nn.Parameter(self.weight)
......
...@@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
...@@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise NotImplementedError( raise NotImplementedError(
"Expert Parallelism is not supported for " "Expert Parallelism is not supported for "
"fused Marlin MoE method.") "fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -224,6 +224,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -240,13 +241,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -240,13 +241,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(
x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True, use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -438,6 +441,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -474,6 +478,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
out_dtype=x.dtype, out_dtype=x.dtype,
apply_router_weight_on_input=apply_router_weight_on_input,
) )
...@@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -778,6 +783,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
...@@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -785,6 +791,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
raise NotImplementedError( raise NotImplementedError(
"Expert Parallelism is not supported for " "Expert Parallelism is not supported for "
"fused Marlin MoE method.") "fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for "
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -113,6 +113,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -129,7 +130,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -129,7 +130,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(
x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
...@@ -138,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): ...@@ -138,6 +140,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
activation=activation, activation=activation,
use_int8_w8a16=True, use_int8_w8a16=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_scale, w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale) w2_scale=layer.w2_scale)
......
...@@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -773,6 +773,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -800,6 +801,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
activation=activation, activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale), if self.block_quant else layer.w13_weight_scale),
......
...@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase): ...@@ -338,9 +338,15 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
): ):
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused GGUF MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if apply_router_weight_on_input is not None:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype
......
...@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -312,7 +313,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -312,7 +313,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
return fused_experts(x, return fused_experts(
x,
layer.w13_qweight, layer.w13_qweight,
layer.w2_qweight, layer.w2_qweight,
topk_weights=topk_weights, topk_weights=topk_weights,
...@@ -321,6 +323,7 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -321,6 +323,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_scales, w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales, w2_scale=layer.w2_scales,
......
...@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -217,7 +219,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -217,7 +219,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts(x, return fused_experts(
x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
...@@ -225,6 +228,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -225,6 +228,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
inplace=True, inplace=True,
use_fp8_w8a8=True, use_fp8_w8a8=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
......
...@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs return new_freqs
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
    """Two-axis (x / y) rotary embedding over a grid of image patches.

    The cos/sin cache is stored as a complex tensor with one entry per
    patch plus one trailing slot for the class token; ``forward`` applies
    it to the query and key by complex multiplication.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        # Straight pass-through to the parent constructor. Note that
        # ``max_position_embeddings`` is interpreted as a patch count by
        # ``_compute_cos_sin_cache`` below.
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the parent's inverse frequencies, cut to
        ``rotary_dim // 2`` entries."""
        keep = self.rotary_dim // 2
        return super()._compute_inv_freq(base)[:keep]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the complex rotation cache for every patch position.

        Returns:
            A complex tensor obtained from stacked ``cos``/``sin`` values,
            with one row per patch and a final row for the class token
            whose angles are all zero.
        """
        inv_freq = self._compute_inv_freq(self.base)

        # self.max_position_embeddings here is number of image patches
        # i.e. (image_size // patch_size) ** 2
        patch_count = self.max_position_embeddings

        # Column vector of patch ids, with one extra row appended for the
        # class token.
        patch_ids = torch.arange(patch_count, dtype=torch.int32)
        patch_ids = patch_ids.reshape(patch_count, 1)
        patch_ids = torch.cat([patch_ids, patch_ids[:1]], dim=0)
        patch_ids[-1, -1] = -2  # set to ID_CLS_TOKEN

        grid_side = int(math.sqrt(patch_count))

        def axis_freqs(coords: torch.Tensor) -> torch.Tensor:
            # Coordinates are shifted by one before being scaled by the
            # inverse frequencies; each value is then duplicated along the
            # last dimension.
            scaled = (coords + 1)[..., None] * inv_freq[None, None, :]
            return scaled.repeat_interleave(2, dim=-1)

        freqs_x = axis_freqs(patch_ids % grid_side)
        freqs_y = axis_freqs(patch_ids // grid_side)

        # Concatenate both axes, then keep every second element of the
        # last dimension.
        freqs = torch.cat([freqs_x, freqs_y], dim=-1)
        freqs = freqs.float().contiguous()[..., ::2]

        # Negative ids (the class-token row) get a zero angle.
        is_special = patch_ids.reshape(-1, 1, 1) < 0
        freqs = freqs.masked_fill(is_special, 0)

        cos_sin = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        return torch.view_as_complex(cos_sin)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``query`` and ``key`` with the cached complex factors.

        Args:
            query: Query tensor; its last dimension is split into
                (real, imag) pairs. NOTE(review): ``flatten(3)`` below
                implies a 4-D input - confirm against the caller.
            key: Key tensor, handled the same way as ``query``.

        Returns:
            The rotated ``(query, key)`` pair, each cast back to the dtype
            of its corresponding input.
        """
        # Move the cache to the query's device and keep it there.
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)

        def as_complex(tensor: torch.Tensor) -> torch.Tensor:
            # Compute in float32 and pair up the last dimension.
            paired = tensor.float().reshape(*tensor.shape[:-1], -1, 2)
            return torch.view_as_complex(paired)

        query_c = as_complex(query)
        key_c = as_complex(key)

        # Keep the sizes of axis 1 and the last axis; every other axis is
        # set to 1 so the cache broadcasts over it.
        last_axis = query_c.ndim - 1
        broadcast_shape = [
            size if axis == 1 or axis == last_axis else 1
            for axis, size in enumerate(query_c.shape)
        ]
        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)

        query_out = torch.view_as_real(query_c * freqs_ci).flatten(3)
        key_out = torch.view_as_real(key_c * freqs_ci).flatten(3)
        return query_out.type_as(query), key_out.type_as(key)
class MRotaryEmbedding(RotaryEmbedding): class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections.""" """Rotary Embedding with Multimodal Sections."""
...@@ -1130,6 +1194,10 @@ def get_rope( ...@@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor, low_freq_factor, scaling_factor, low_freq_factor,
high_freq_factor, high_freq_factor,
original_max_position) original_max_position)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
elif scaling_type == "default": elif scaling_type == "default":
if "mrope_section" in rope_scaling: if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding( rotary_emb = MRotaryEmbedding(
......
...@@ -111,9 +111,11 @@ def _initialize_model( ...@@ -111,9 +111,11 @@ def _initialize_model(
vllm_config: VllmConfig, vllm_config: VllmConfig,
*, *,
prefix: str = "", prefix: str = "",
model_class: Optional[type[nn.Module]] = None,
) -> nn.Module: ) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_config = vllm_config.model_config model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None: if vllm_config.quant_config is not None:
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module): ...@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
prefix: str = "", prefix: str = "",
reduce_results: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
...@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module): ...@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size, output_size=hidden_size,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
) )
if hidden_act != "silu": if hidden_act != "silu":
...@@ -292,7 +294,7 @@ class LlamaModel(nn.Module): ...@@ -292,7 +294,7 @@ class LlamaModel(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm", "ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens", "tok_embeddings": "model.embed_tokens",
"output": "lm_head", "output": "lm_head",
"norm": "model.norm" "norm": "model.norm",
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
...@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self,
return LlamaModel(vllm_config=vllm_config, prefix=prefix) vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
This diff is collapsed.
This diff is collapsed.
...@@ -196,6 +196,7 @@ _MULTIMODAL_MODELS = { ...@@ -196,6 +196,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder] # [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
} }
......
...@@ -19,9 +19,10 @@ ...@@ -19,9 +19,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, Set, Tuple, Type from typing import Iterable, Set, Tuple
import torch import torch
import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -124,7 +125,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM): ...@@ -124,7 +125,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
def _init_model(self, def _init_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
......
...@@ -22,9 +22,8 @@ ...@@ -22,9 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Type
import torch import torch
import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -39,7 +38,7 @@ class TeleFLMModel(LlamaModel): ...@@ -39,7 +38,7 @@ class TeleFLMModel(LlamaModel):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer, layer_type: type[nn.Module] = LlamaDecoderLayer,
): ):
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment