change/sglang · Commits · 94752ac8
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "eb38c7d1cae1c616de3c1c0ce40353c720f7e3c7"
Commit 94752ac8 (Unverified)
authored Aug 11, 2024 by Yineng Zhang, committed by GitHub on Aug 11, 2024

feat: use FlashInfer rmsnorm and silu (#907)

parent 43fbb6d9
Showing 6 changed files with 156 additions and 10 deletions (+156 / -10):
python/sglang/srt/layers/activation.py   +29  -0
python/sglang/srt/layers/layernorm.py    +62  -0
python/sglang/srt/models/internlm2.py     +2  -7
python/sglang/srt/models/llama2.py        +2  -2
python/sglang/srt/server.py               +1  -1
python/sglang/test/test_layernorm.py     +60  -0
python/sglang/srt/layers/activation.py  (new file, 0 → 100644)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul


class SiluAndMul(nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference: split [gate, up] along the last dim and gate with SiLU.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # FlashInfer fused kernel: writes silu(gate) * up into a preallocated output.
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
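
A quick usage sketch, not part of the commit: the module expects the concatenated [gate, up] projection output and returns the SiLU-gated half. The shapes, dtype, and device below are illustrative assumptions, the fused path requires a CUDA tensor with flashinfer installed, and the tolerance mirrors the one used in the layernorm test further down.

import torch

from sglang.srt.layers.activation import SiluAndMul

# Hypothetical sizes: 4 tokens, intermediate size 11008, so the merged
# gate/up projection yields 2 * 11008 features per token.
act = SiluAndMul()
gate_up = torch.randn(4, 2 * 11008, dtype=torch.half, device="cuda")

out = act(gate_up)                 # FlashInfer silu_and_mul kernel
ref = act.forward_native(gate_up)  # pure-PyTorch reference path

assert out.shape == (4, 11008)
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)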
python/sglang/srt/layers/layernorm.py  (new file, 0 → 100644)
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm


class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # FlashInfer fused kernel: adds the residual and normalizes in place.
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Pure-PyTorch reference implementation, computed in float32.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual
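
For illustration only, assuming a CUDA device, flashinfer installed, and an arbitrary hidden size of 4096: the layer dispatches to rmsnorm when no residual is given and to fused_add_rmsnorm otherwise, with the fused kernel updating both tensors in place before they are returned as a tuple.

import torch

from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(4096).to(dtype=torch.half, device="cuda")

x = torch.randn(8, 4096, dtype=torch.half, device="cuda")
residual = torch.randn_like(x)

out = norm(x)  # rmsnorm kernel: plain RMSNorm of x

# fused_add_rmsnorm: residual <- x + residual, x <- rmsnorm(x + residual), both in place
x2, res2 = norm(x, residual)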
python/sglang/srt/models/internlm2.py
@@ -23,8 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 class InternLM2MLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):
 class InternLM2Attention(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):
 class InternLMDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):
 class InternLM2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):
 class InternLM2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
python/sglang/srt/models/llama2.py
@@ -24,8 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,6 +37,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
python/sglang/srt/server.py
@@ -384,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.3",
+            "0.1.4",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
python/sglang/test/test_layernorm.py  (new file, 0 → 100644)
import itertools
import unittest

import torch

from sglang.srt.layers.layernorm import RMSNorm


class TestRMSNorm(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_rms_norm(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                hidden_size=params[1],
                add_residual=params[2],
                dtype=params[3],
                seed=params[4],
            ):
                self._run_rms_norm_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
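
Because the file ends with a unittest.main entry point, the suite can be run directly, for example with python python/sglang/test/test_layernorm.py; on a machine without CUDA the whole class is skipped by setUpClass, and the FlashInfer kernels are compared against forward_native with atol=rtol=1e-2.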