Unverified Commit 9545bfb2 authored by Xiuyu Li, committed by GitHub

fix: support gelu_new activation function in gpt2 (#3712)

parent 37373ef2
sglang/srt/layers/activation.py

@@ -14,6 +14,7 @@
 """Fused operators for activation layers."""
 
 import logging
+import math
 from typing import Optional
 
 import torch
@@ -72,6 +73,16 @@ class GeluAndMul(CustomOp):
         return out
 
 
+class NewGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
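The forward_native formula above is the tanh approximation of GELU that HuggingFace exposes under the name gelu_new. PyTorch ships the same approximation built in, so a quick sanity check (a minimal sketch, not part of this commit) can confirm the two match:

import math

import torch

def new_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same tanh approximation as NewGELU.forward_native above.
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.randn(8, 16)
# F.gelu with approximate="tanh" computes the same formula.
torch.testing.assert_close(
    new_gelu(x), torch.nn.functional.gelu(x, approximate="tanh")
)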
sglang/srt/models/gpt2.py

@@ -17,14 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Type
 
 import torch
 from torch import nn
 from transformers import GPT2Config
 
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
-from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.activation import NewGELU
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -97,6 +97,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -116,9 +117,7 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(
-            config.activation_function, quant_config, intermediate_size
-        )
+        self.act = act_layer()
 
     def forward(
         self,
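Passing the activation as a class (act_layer: Type[nn.Module]) and instantiating it in __init__ sidesteps the string-keyed get_act_fn lookup, which did not resolve gelu_new. A minimal sketch of the pattern, using a hypothetical TinyMLP in place of GPT2MLP:

from typing import Type

import torch
from torch import nn

class TinyMLP(nn.Module):
    # Hypothetical stand-in for GPT2MLP: the activation arrives as a class
    # and is instantiated inside __init__, mirroring `self.act = act_layer()`.
    def __init__(self, dim: int, act_layer: Type[nn.Module] = nn.GELU):
        super().__init__()
        self.c_fc = nn.Linear(dim, 4 * dim)
        self.c_proj = nn.Linear(4 * dim, dim)
        self.act = act_layer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c_proj(self.act(self.c_fc(x)))

mlp = TinyMLP(dim=32)  # the default class is used here, as NewGELU is in the diff
y = mlp(torch.randn(2, 32))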
@@ -136,6 +135,7 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -148,7 +148,13 @@ class GPT2Block(nn.Module):
             layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
 
     def forward(
         self,
@@ -190,7 +196,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, quant_config)
+                GPT2Block(i, config, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
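This default wiring matters because stock HuggingFace GPT-2 configs request gelu_new, which the commit title indicates was previously unsupported here. A quick check of the config default:

from transformers import GPT2Config

cfg = GPT2Config()
print(cfg.activation_function)  # "gelu_new", the GPT-2 default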