"tests/v1/entrypoints/openai/test_completion.py" did not exist on "dfb1a15dcb4c24bf7ff0ba7ddfc5d623ad519d7d"
Unverified Commit 5b8a1fde authored by Junhao Li's avatar Junhao Li Committed by GitHub
Browse files

[Model][Bugfix] Add FATReLU activation and support for openbmb/MiniCPM-S-1B-sft (#9396)

parent fb60ae9b
......@@ -159,7 +159,7 @@ Text Generation
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
- ✅︎
- ✅︎
* - :code:`MiniCPM3ForCausalLM`
......
......@@ -13,6 +13,33 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, threshold: float = 0.):
super().__init__()
self.threshold = threshold
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
x1 = F.threshold(x1, self.threshold, 0.0)
return x1 * x2
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
......
......@@ -33,7 +33,7 @@ from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -152,6 +152,7 @@ class MiniCPMMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
hidden_act_param: float,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
......@@ -163,10 +164,13 @@ class MiniCPMMLP(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
if hidden_act == "silu":
self.act_fn = SiluAndMul()
elif hidden_act == "fatrelu":
self.act_fn = FatreluAndMul(threshold=hidden_act_param)
else:
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu and fatrelu are supported for now.")
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
......@@ -304,6 +308,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=self.config.intermediate_size,
hidden_act=self.config.hidden_act,
hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
quant_config=self.quant_config,
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment