sglang · commit 3602692c (unverified)
Authored Aug 27, 2024 by Yineng Zhang, committed via GitHub on Aug 27, 2024
Parent: 909f3436

feat: replace get_act_fn for gpt_bigcode (#1231)

Showing 2 changed files with 84 additions and 1 deletion:

  python/sglang/srt/layers/activation.py   +83  -0
  python/sglang/srt/models/gpt_bigcode.py   +1  -1
python/sglang/srt/layers/activation.py

@@ -13,10 +13,20 @@ limitations under the License.
 """Fused operators for activation layers."""
 
+from typing import Optional
+
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from vllm.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs
 
 
 class SiluAndMul(CustomOp):
@@ -53,3 +63,76 @@ class GeluAndMul(CustomOp):

The trailing lines of GeluAndMul's forward implementation are unchanged context:

        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        gelu_tanh_and_mul(x, out)
        return out
The hunk then appends a new ScaledActivation module:

class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
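Under tensor parallelism each rank owns only a slice of the intermediate dimension, so weight_loader narrows the full scales tensor from the checkpoint down to the local shard. A minimal sketch of that indexing, using made-up sizes in place of the real vllm.distributed helpers:

```python
import torch

# Hypothetical sizes for illustration: intermediate_size=8 split across tp_size=4 ranks.
intermediate_size = 8
tp_size = 4
shard_size = intermediate_size // tp_size  # mirrors divide(intermediate_size, tp_size)

full_scales = torch.arange(intermediate_size, dtype=torch.float32)  # tensor from the checkpoint

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    local_shard = full_scales.narrow(0, start_idx, shard_size)  # this rank's slice
    print(f"rank {tp_rank}: {local_shard.tolist()}")
# rank 0: [0.0, 1.0]  rank 1: [2.0, 3.0]  rank 2: [4.0, 5.0]  rank 3: [6.0, 7.0]
```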
It also adds a registry of supported activation functions:

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
}
Finally, it adds a get_act_fn helper that resolves an activation by name and wraps it in ScaledActivation when the quantization config marks it as a scaled activation:

def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (
        quant_config is not None
        and act_fn_name in quant_config.get_scaled_act_names()
    ):
        if intermediate_size is None:
            raise ValueError(
                "intermediate_size must be specified for scaled "
                "activation functions."
            )
        return ScaledActivation(
            act_fn, intermediate_size, input_is_parallel, params_dtype
        )
    return act_fn
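A quick sketch of both code paths through get_act_fn, assuming an environment where sglang and its flashinfer/vllm dependencies are importable; the quantization config below is a hypothetical stub used only to trigger the ScaledActivation branch:

```python
import torch

from sglang.srt.layers.activation import ScaledActivation, get_act_fn

# Plain path: the name is looked up in _ACTIVATION_REGISTRY and returned as-is.
act = get_act_fn("gelu_pytorch_tanh")
y = act(torch.randn(4, 16))  # tanh-approximated GELU, same shape as the input


# Scaled path: a stub (not a real vllm QuantizationConfig) that reports "gelu"
# as a scaled activation, so get_act_fn wraps the module in ScaledActivation.
class _StubQuantConfig:
    def get_scaled_act_names(self):
        return ["gelu"]


scaled = get_act_fn(
    "gelu",
    quant_config=_StubQuantConfig(),
    intermediate_size=16,
    input_is_parallel=False,  # keep the sketch independent of a TP process group
)
assert isinstance(scaled, ScaledActivation)
# scaled.scales is an uninitialized 16-element parameter; real values come from
# the checkpoint via ScaledActivation.weight_loader.
```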
python/sglang/srt/models/gpt_bigcode.py

The model file swaps its get_act_fn import from vLLM's implementation to the new sglang one:

@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -33,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.sampler import Sampler
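The call sites in the model are untouched; they now simply resolve against sglang's registry. Roughly, the activation name comes from the Hugging Face config, as in this sketch (assuming transformers and a working sglang install; the GPT-2-style attribute names n_inner/n_embd are taken from GPTBigCodeConfig, not from this diff):

```python
from transformers import GPTBigCodeConfig

from sglang.srt.layers.activation import get_act_fn

config = GPTBigCodeConfig()  # default activation_function is "gelu_pytorch_tanh"
intermediate_size = config.n_inner if config.n_inner is not None else 4 * config.n_embd

# Without a quantization config this returns the registry's nn.GELU(approximate="tanh");
# with a quant method whose get_scaled_act_names() includes the name, the module would
# be wrapped in ScaledActivation instead.
act = get_act_fn(
    config.activation_function,
    quant_config=None,
    intermediate_size=intermediate_size,
)
```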