sglang · Commits · 9545bfb2

Commit 9545bfb2 (unverified), authored Mar 04, 2025 by Xiuyu Li, committed by GitHub on Mar 04, 2025
Parent: 37373ef2

fix: support gelu_new activation function in gpt2 (#3712)
Showing 2 changed files with 24 additions and 7 deletions (+24, -7):

  python/sglang/srt/layers/activation.py  (+11, -0)
  python/sglang/srt/models/gpt2.py        (+13, -7)
python/sglang/srt/layers/activation.py (view file @ 9545bfb2)
@@ -14,6 +14,7 @@
 """Fused operators for activation layers."""

 import logging
+import math
 from typing import Optional

 import torch
@@ -72,6 +73,16 @@ class GeluAndMul(CustomOp):
         return out


+class NewGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
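The NewGELU class added above implements the tanh approximation of GELU, which is the "gelu_new" activation that GPT-2 checkpoints expect. As a sanity check, the formula should agree with PyTorch's built-in tanh-approximate GELU; a minimal sketch, assuming PyTorch >= 1.12 (where F.gelu accepts approximate="tanh"):

import math
import torch
import torch.nn.functional as F

def new_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Same formula as NewGELU.forward_native in the diff above.
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.randn(4, 8)
# torch's tanh-approximate GELU uses the same 0.044715 cubic correction term.
assert torch.allclose(new_gelu_reference(x), F.gelu(x, approximate="tanh"), atol=1e-6)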
python/sglang/srt/models/gpt2.py (view file @ 9545bfb2)
@@ -17,14 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Type

 import torch
 from torch import nn
 from transformers import GPT2Config

 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
-from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.activation import NewGELU
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -97,6 +97,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -116,9 +117,7 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(
-            config.activation_function, quant_config, intermediate_size
-        )
+        self.act = act_layer()

     def forward(
         self,
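The hunk above is the core of the fix: the get_act_fn(config.activation_function, ...) lookup is replaced, per the commit title, because it did not support "gelu_new"; GPT2MLP now receives an activation class and instantiates it directly. A minimal sketch of that injection pattern, using a simplified stand-in module (the real GPT2MLP also needs parallel linear layers and a quant_config):

from typing import Type

import torch
from torch import nn


class TinyMLP(nn.Module):
    """Simplified stand-in for GPT2MLP: the activation is injected as a class."""

    def __init__(self, hidden: int, inner: int, act_layer: Type[nn.Module] = nn.GELU):
        super().__init__()
        self.c_fc = nn.Linear(hidden, inner)
        self.c_proj = nn.Linear(inner, hidden)
        self.act = act_layer()  # mirrors `self.act = act_layer()` in the diff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c_proj(self.act(self.c_fc(x)))


# nn.GELU stands in for sglang's NewGELU here; in the real model the default
# act_layer is NewGELU, matching GPT-2's "gelu_new" activation.
mlp = TinyMLP(hidden=16, inner=64)
out = mlp(torch.randn(2, 16))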
@@ -136,6 +135,7 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -148,7 +148,13 @@
             layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )

     def forward(
         self,
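Throughout the diff, act_layer simply defaults to NewGELU, which is correct for standard GPT-2 checkpoints (config.activation_function == "gelu_new"). If other activation names ever needed to be honored, a small dispatch could translate the config value into an act_layer argument; the helper below is purely illustrative and is not part of this commit:

from typing import Type

from torch import nn

from sglang.srt.layers.activation import NewGELU, QuickGELU

# Hypothetical name-to-class table; GPT-2 itself only exercises "gelu_new".
_ACT_LAYERS: dict = {
    "gelu_new": NewGELU,
    "quick_gelu": QuickGELU,
}


def resolve_act_layer(name: str) -> Type[nn.Module]:
    # Fail loudly rather than silently substituting a different activation.
    if name not in _ACT_LAYERS:
        raise ValueError(f"Unsupported activation function: {name}")
    return _ACT_LAYERS[name]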
@@ -190,7 +196,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, quant_config)
+                GPT2Block(i, config, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
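Note why this last hunk switches quant_config to a keyword argument: with act_layer inserted into GPT2Block's signature ahead of quant_config, the old positional call would have bound the quantization config to the activation slot.

# New GPT2Block signature (abridged from the diff above):
#     def __init__(self, layer_id, config, act_layer=NewGELU, quant_config=None, prefix="")
#
# GPT2Block(i, config, quant_config)               # old call: quant_config would now bind to act_layer
# GPT2Block(i, config, quant_config=quant_config)  # updated call: binds the intended parameter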