Unverified Commit c411f32e authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

feat: replace GeluAndMul (#1234)

parent bf53bf51
...@@ -18,7 +18,7 @@ from typing import Optional ...@@ -18,7 +18,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
...@@ -43,18 +43,24 @@ class SiluAndMul(CustomOp): ...@@ -43,18 +43,24 @@ class SiluAndMul(CustomOp):
class GeluAndMul(CustomOp): class GeluAndMul(CustomOp):
def __init__(self, **kwargs): def __init__(self, approximate="tanh"):
super().__init__() super().__init__()
self.approximate = approximate
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate="tanh") * x[..., d:] return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,) output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "tanh":
gelu_tanh_and_mul(x, out) gelu_tanh_and_mul(x, out)
elif self.approximate == "none":
gelu_and_mul(x, out)
else:
raise RuntimeError("GeluAndMul only support tanh or none")
return out return out
......
...@@ -23,7 +23,6 @@ from torch import nn ...@@ -23,7 +23,6 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -34,6 +33,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -34,6 +33,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
...@@ -60,7 +60,7 @@ class GemmaMLP(nn.Module): ...@@ -60,7 +60,7 @@ class GemmaMLP(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
) )
self.act_fn = GeluAndMul() self.act_fn = GeluAndMul("none")
def forward(self, x): def forward(self, x):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
......
...@@ -96,7 +96,7 @@ class TestGenerationModels(unittest.TestCase): ...@@ -96,7 +96,7 @@ class TestGenerationModels(unittest.TestCase):
if hf_logprobs.shape[0] <= 100: if hf_logprobs.shape[0] <= 100:
assert torch.all( assert torch.all(
abs(hf_logprobs - srt_logprobs) < prefill_tolerance abs(hf_logprobs - srt_logprobs) < prefill_tolerance
), "prefill logprobs are not all close" ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
print(f"hf_outputs.output_strs={hf_outputs.output_strs}") print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
print(f"srt_outputs.output_strs={srt_outputs.output_strs}") print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment