zhaoyu6 / sglang · commit c9064e6f (unverified)

feat: use gelu_tanh_and_mul (#1193)

Authored by Yineng Zhang on Aug 24, 2024; committed via GitHub on Aug 24, 2024.
Parent: a5b14ad0

Showing 3 changed files with 74 additions and 3 deletions.
python/sglang/srt/layers/activation.py   +17 -1
python/sglang/srt/models/gemma2.py        +2 -2
python/sglang/test/test_activation.py    +55 -0
python/sglang/srt/layers/activation.py

```diff
@@ -15,7 +15,7 @@ limitations under the License.
 import torch
 import torch.nn.functional as F
-from flashinfer.activation import silu_and_mul
+from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
 from vllm.model_executor.custom_op import CustomOp
@@ -37,3 +37,19 @@ class SiluAndMul(CustomOp):
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         silu_and_mul(x, out)
         return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        gelu_tanh_and_mul(x, out)
+        return out
```
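The new GeluAndMul op implements the gated tanh-GELU (GeGLU) used by Gemma-2: the last dimension of the input holds the gate half and the up half side by side. Below is a minimal pure-PyTorch sketch of the same semantics, runnable on CPU; the fused flashinfer `gelu_tanh_and_mul` kernel is only exercised on the CUDA path above, and the helper name `gelu_tanh_and_mul_reference` is illustrative, not part of the commit.

```python
import torch
import torch.nn.functional as F


def gelu_tanh_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: the first half passes through
    # tanh-approximate GELU, the second half gates it elementwise
    # (the same math as GeluAndMul.forward_native).
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]


x = torch.randn(4, 2 * 8)             # (num_tokens, 2 * d)
out = gelu_tanh_and_mul_reference(x)
print(out.shape)                      # torch.Size([4, 8])
```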
python/sglang/srt/models/gemma2.py

```diff
@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import GeluAndMul
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
@@ -39,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -135,7 +135,7 @@ class Gemma2MLP(nn.Module):
                 "function. Please set `hidden_act` and `hidden_activation` to "
                 "`gelu_pytorch_tanh`."
             )
-        self.act_fn = GeluAndMul(approximate="tanh")
+        self.act_fn = GeluAndMul()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         gate_up, _ = self.gate_up_proj(x)
```
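For context, Gemma2MLP computes a single merged gate/up projection and hands the concatenated tensor to GeluAndMul, which is why the activation expects an input of width 2 * intermediate_size. The following is a hedged toy sketch of that data flow, using plain `nn.Linear` in place of vLLM's parallel linear layers and an inline gated GELU in place of GeluAndMul; the class name and sizes are illustrative only, not the actual Gemma2MLP.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGemmaMLP(nn.Module):
    """Illustrative stand-in for Gemma2MLP, not the real implementation."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Gate and up projections merged into one matmul, as in gate_up_proj.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)        # (..., 2 * intermediate_size)
        d = gate_up.shape[-1] // 2
        # Same gated tanh-GELU that GeluAndMul applies.
        act = F.gelu(gate_up[..., :d], approximate="tanh") * gate_up[..., d:]
        return self.down_proj(act)


mlp = ToyGemmaMLP(hidden_size=16, intermediate_size=32)
print(mlp(torch.randn(2, 16)).shape)          # torch.Size([2, 16])
```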
python/sglang/test/test_activation.py (new file, 0 → 100644)

```python
import itertools
import unittest

import torch

from sglang.srt.layers.activation import GeluAndMul


class TestGeluAndMul(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed):
        torch.manual_seed(seed)

        layer = GeluAndMul().to(dtype=dtype)
        x = torch.randn(num_tokens, 2 * d, dtype=dtype)

        with torch.inference_mode():
            ref_out = layer.forward_native(x)
            out = layer.forward_cuda(x)

        if dtype == torch.bfloat16:
            atol = rtol = 1e-2
        else:
            atol = rtol = 1e-3

        self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol))

    def test_gelu_and_mul(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                d=params[1],
                dtype=params[2],
                seed=params[3],
            ):
                self._run_gelu_and_mul_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
```
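The test compares the fused CUDA path (forward_cuda, backed by flashinfer's gelu_tanh_and_mul) against the pure-PyTorch reference (forward_native) across several token counts, hidden sizes, and dtypes, using a looser tolerance for bfloat16 than for float16. Because the file ends with unittest.main, it can be run directly on a CUDA machine, e.g. `python python/sglang/test/test_activation.py`.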