Commit c411f32e (Unverified)
Authored Aug 29, 2024 by Yineng Zhang; committed by GitHub Aug 28, 2024

feat: replace GeluAndMul (#1234)

parent bf53bf51
Showing 3 changed files with 13 additions and 7 deletions (+13 -7)

python/sglang/srt/layers/activation.py     +10 -4
python/sglang/srt/models/gemma.py          +2  -2
test/srt/models/test_generation_models.py  +1  -1
python/sglang/srt/layers/activation.py
...
@@ -18,7 +18,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
...
@@ -43,18 +43,24 @@ class SiluAndMul(CustomOp):
 class GeluAndMul(CustomOp):
-    def __init__(self, **kwargs):
+    def __init__(self, approximate="tanh"):
         super().__init__()
+        self.approximate = approximate

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        gelu_tanh_and_mul(x, out)
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only support tanh or none")
         return out
...
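This diff replaces GeluAndMul's hard-coded tanh approximation with a configurable approximate argument, dispatching to flashinfer's gelu_tanh_and_mul or gelu_and_mul kernel as appropriate. For a rough illustration of the gated-GELU ("GeGLU") pattern the class implements, here is a minimal PyTorch-only sketch; the gelu_and_mul_ref helper is hypothetical, mirroring forward_native rather than being sglang's own API:

import torch
import torch.nn.functional as F

def gelu_and_mul_ref(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Split the last dimension in half: GELU gates the first half,
    # the second half is the multiplicative ("mul") branch.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

x = torch.randn(2, 8)
exact = gelu_and_mul_ref(x, approximate="none")  # erf-based GELU
tanh = gelu_and_mul_ref(x, approximate="tanh")   # tanh approximation
print((exact - tanh).abs().max())  # small but nonzero numerical difference

The two variants agree closely but not exactly, which is presumably why the gemma.py change below passes "none" explicitly instead of relying on the new tanh default.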
python/sglang/srt/models/gemma.py
...
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
...
@@ -34,6 +33,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
...
@@ -60,7 +60,7 @@ class GemmaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
...
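For context, GemmaMLP computes one fused gate/up projection and then applies GeluAndMul to the concatenated halves. A self-contained sketch of that dataflow, assuming a plain nn.Linear in place of MergedColumnParallelLinear and hypothetical layer sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGluMLP(nn.Module):
    # Hypothetical stand-in for GemmaMLP: one fused projection emits the
    # [gate, up] halves that GeluAndMul splits and combines.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Fused gate+up projection (MergedColumnParallelLinear in the real model).
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)
        gate, up = gate_up.chunk(2, dim=-1)
        # GeluAndMul("none") in the diff corresponds to exact (erf) GELU here.
        return self.down_proj(F.gelu(gate, approximate="none") * up)

mlp = GeGluMLP(hidden_size=16, intermediate_size=64)
print(mlp(torch.randn(2, 16)).shape)  # torch.Size([2, 16])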
test/srt/models/test_generation_models.py
...
@@ -96,7 +96,7 @@ class TestGenerationModels(unittest.TestCase):
         if hf_logprobs.shape[0] <= 100:
             assert torch.all(
                 abs(hf_logprobs - srt_logprobs) < prefill_tolerance
-            ), "prefill logprobs are not all close"
+            ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"

         print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
         print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
...
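The test change only enriches the assertion message so that a failure reports which model, prompts, and tolerance were in play. A small standalone sketch of the same pattern; the check_logprobs_close helper and its arguments are hypothetical, for illustration only:

import torch

def check_logprobs_close(hf_logprobs, srt_logprobs, prefill_tolerance, model_path, prompts):
    # Compare logprobs elementwise and embed the run's parameters in the
    # failure message so a CI failure is self-describing.
    assert torch.all(
        (hf_logprobs - srt_logprobs).abs() < prefill_tolerance
    ), (
        f"prefill logprobs are not all close with model_path={model_path} "
        f"prompts={prompts} prefill_tolerance={prefill_tolerance}"
    )

check_logprobs_close(
    torch.tensor([-1.000, -2.000]),
    torch.tensor([-1.001, -2.001]),
    prefill_tolerance=0.05,
    model_path="dummy/model",
    prompts=["hello"],
)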