Unverified Commit 8db776f0 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

support QuickGELU (#3250)

parent 4eb4b401
...@@ -72,6 +72,15 @@ class GeluAndMul(CustomOp): ...@@ -72,6 +72,15 @@ class GeluAndMul(CustomOp):
return out return out
class QuickGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
return self.forward_native(x)
class ScaledActivation(nn.Module): class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters. """An activation function with post-scale parameters.
......
...@@ -31,10 +31,10 @@ import torch ...@@ -31,10 +31,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from vllm.model_executor.layers.activation import QuickGELU
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment