Unverified commit 9545bfb2 authored by Xiuyu Li, committed by GitHub

fix: support gelu_new activation function in gpt2 (#3712)

parent 37373ef2
@@ -14,6 +14,7 @@
 """Fused operators for activation layers."""
 import logging
+import math
 from typing import Optional
 import torch
@@ -72,6 +73,16 @@ class GeluAndMul(CustomOp):
         return out


+class NewGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
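For reference, the expression in NewGELU.forward_native above is the tanh approximation of GELU that HuggingFace configs name gelu_new. A minimal standalone sketch (assuming PyTorch >= 1.12 and not importing sglang's CustomOp) that checks the same formula against torch.nn.functional.gelu with approximate="tanh":

import math

import torch
import torch.nn.functional as F


def new_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Same expression as NewGELU.forward_native:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.randn(4, 8)
# The two should agree to floating-point tolerance.
assert torch.allclose(new_gelu_reference(x), F.gelu(x, approximate="tanh"), atol=1e-6)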
@@ -17,14 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Type

 import torch
 from torch import nn
 from transformers import GPT2Config

 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
-from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.activation import NewGELU
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -97,6 +97,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -116,9 +117,7 @@
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(
-            config.activation_function, quant_config, intermediate_size
-        )
+        self.act = act_layer()

     def forward(
         self,
@@ -136,6 +135,7 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -148,7 +148,13 @@
             layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )

     def forward(
         self,
@@ -190,7 +196,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, quant_config)
+                GPT2Block(i, config, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
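With act_layer now a constructor argument, GPT2MLP and GPT2Block default to NewGELU instead of going through get_act_fn. A hypothetical sketch of how a caller could still pick the layer from config.activation_function; the _ACT_LAYERS table and resolve_act_layer helper below are illustrative assumptions, not part of this commit:

from typing import Dict, Type

from torch import nn

from sglang.srt.layers.activation import NewGELU, QuickGELU

# Illustrative mapping from HuggingFace activation names to layer classes.
_ACT_LAYERS: Dict[str, Type[nn.Module]] = {
    "gelu_new": NewGELU,
    "quick_gelu": QuickGELU,
    "relu": nn.ReLU,
}


def resolve_act_layer(activation_function: str) -> Type[nn.Module]:
    # Fall back to NewGELU, GPT-2's default activation, for unknown names.
    return _ACT_LAYERS.get(activation_function, NewGELU)


# Usage sketch:
# GPT2MLP(inner_dim, config, act_layer=resolve_act_layer(config.activation_function), ...)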