zhaoyu6 / sglang · Commit 3ae78a09 (Unverified)

Add gptq quantization model support (#141)

Authored Feb 07, 2024 by Arcmoon, committed via GitHub on Feb 06, 2024
Parent: ccbe1e67

Showing 4 changed files with 25 additions and 14 deletions (+25, -14)
python/sglang/srt/layers/radix_attention.py         +1   -2
python/sglang/srt/managers/router/model_runner.py   +7   -2
python/sglang/srt/models/qwen.py                    +16  -9
python/sglang/srt/utils.py                          +1   -1
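Usage sketch (not part of the diff): after this change, a checkpoint whose Hugging Face `config.json` carries a `quantization_config` with `quant_method: "gptq"` should be routed through vLLM's `GPTQConfig` rather than always through `AWQConfig`. Assuming the `sglang.Runtime` frontend entry point and an example GPTQ checkpoint name, loading might look like:

```python
# Hedged usage sketch, not part of the commit. The model path is only an
# example of a GPTQ-quantized checkpoint; any model whose config.json contains
# a "quantization_config" with quant_method == "gptq" should hit the new path.
import sglang as sgl

runtime = sgl.Runtime(model_path="Qwen/Qwen-7B-Chat-Int4")
sgl.set_default_backend(runtime)
# ... run sglang programs against the runtime ...
runtime.shutdown()
```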
python/sglang/srt/layers/radix_attention.py

@@ -19,10 +19,9 @@ class RadixAttention(nn.Module):
         head_dim,
         scaling,
         num_kv_heads,
-        layer_id,
+        layer_id
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
python/sglang/srt/managers/router/model_runner.py

@@ -12,10 +12,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
 from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

 import sglang

+QUANTIONCONFIG_MAPPING = {'awq': AWQConfig, 'gptq': GPTQConfig}
+
 logger = logging.getLogger("model_runner")

@@ -280,8 +283,10 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            # TODO: config quantization awq etc
-            quant_config = AWQConfig.from_config(hf_quant_config)
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config['quant_method'])
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
+            quant_config = quant_config_class.from_config(hf_quant_config)
             logger.info(f"quant_config: {quant_config}")
             linear_method = quant_config.get_linear_method()
         model = model_class(
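The second hunk is the heart of the change: the quantization config class is now looked up in QUANTIONCONFIG_MAPPING by the checkpoint's `quant_method` field instead of being hard-wired to AWQConfig, and an unknown method raises a clear error. A self-contained sketch of that dispatch pattern, using stand-in config classes rather than the real vLLM ones:

```python
# Standalone sketch of the quant-config dispatch added in model_runner.py.
# AWQConfig/GPTQConfig here are stand-ins for the vLLM classes; only the
# lookup-and-raise control flow mirrors the diff above.
class AWQConfig:
    @classmethod
    def from_config(cls, cfg):
        return cls()

class GPTQConfig:
    @classmethod
    def from_config(cls, cfg):
        return cls()

QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

def build_quant_config(hf_quant_config):
    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config["quant_method"])
    if quant_config_class is None:
        raise ValueError(f"Unsupported quantization method: {hf_quant_config['quant_method']}")
    return quant_config_class.from_config(hf_quant_config)

print(build_quant_config({"quant_method": "gptq"}))  # -> GPTQConfig instance
```

With this shape, supporting another scheme later amounts to registering one more entry in the mapping.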
python/sglang/srt/models/qwen.py

@@ -34,6 +34,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
+        linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(

@@ -41,12 +42,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
+            linear_method=linear_method
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
+            linear_method=linear_method
         )
         if hidden_act != "silu":
             raise ValueError(

@@ -71,6 +74,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        linear_method: Optional[LinearMethodBase] = None
     ):
         super().__init__()
         self.hidden_size = hidden_size

@@ -82,13 +86,18 @@ class QWenAttention(nn.Module):
         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-            hidden_size, self.head_dim, self.total_num_heads, bias=True
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            linear_method=linear_method
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
+            linear_method=linear_method
         )
         self.rotary_emb = get_rope(
             self.head_dim,

@@ -121,7 +130,7 @@ class QWenAttention(nn.Module):
 class QWenBlock(nn.Module):
-    def __init__(self, config: QWenConfig, layer_id):
+    def __init__(self, config: QWenConfig, layer_id, linear_method=None):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -134,11 +143,12 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
+            linear_method=linear_method
         )
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
+        self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, linear_method=linear_method)

     def forward(
         self,

@@ -165,7 +175,7 @@ class QWenBlock(nn.Module):
 class QWenModel(nn.Module):
-    def __init__(self, config: QWenConfig):
+    def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size

@@ -176,7 +186,7 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
+            [QWenBlock(config, i, linear_method=linear_method) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -202,7 +212,7 @@ class QWenLMHeadModel(nn.Module):
     def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config)
+        self.transformer = QWenModel(config, linear_method=linear_method)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

@@ -219,9 +229,6 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens

-    _column_parallel_weights = []
-    _row_parallel_weights = ["c_proj.weight"]
-
     def load_weights(
         self,
         model_name_or_path: str,
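Taken together, the qwen.py edits thread a single `linear_method` object from `QWenLMHeadModel` down through `QWenModel`, `QWenBlock`, `QWenAttention`, and `QWenMLP` into every parallel linear layer, which is how the GPTQ (or AWQ) kernels get attached. A stripped-down, dependency-free sketch of that plumbing pattern (the classes below are illustrative stand-ins, not the real sglang/vLLM modules):

```python
# Dependency-free sketch of the constructor plumbing added in qwen.py.
# In the real code the leaf layers are vLLM's QKVParallelLinear /
# RowParallelLinear and linear_method comes from quant_config.get_linear_method().
class StubLinear:
    def __init__(self, in_dim, out_dim, linear_method=None):
        # A quantized linear method would create packed weights here.
        self.linear_method = linear_method

class StubBlock:
    def __init__(self, hidden, layer_id, linear_method=None):
        self.proj = StubLinear(hidden, hidden, linear_method=linear_method)

class StubModel:
    def __init__(self, hidden, num_layers, linear_method=None):
        self.blocks = [
            StubBlock(hidden, i, linear_method=linear_method) for i in range(num_layers)
        ]

model = StubModel(hidden=16, num_layers=2, linear_method="gptq-linear-method")
assert all(b.proj.linear_method == "gptq-linear-method" for b in model.blocks)
```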
python/sglang/srt/utils.py

@@ -259,4 +259,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
-    return image
+    return image
\ No newline at end of file
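The utils.py hunk only touches the trailing newline of `load_image`, but for reference the branch shown decodes a base64-encoded image via PIL. A small self-contained round trip of that same expression (assuming Pillow is installed):

```python
# Round-trip check of the base64 branch of load_image (illustrative only).
import base64
from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (4, 4), color="red").save(buf, format="PNG")
encoded = base64.b64encode(buf.getvalue()).decode()

image = Image.open(BytesIO(base64.b64decode(encoded)))  # same expression as in utils.py
print(image.size)  # (4, 4)
```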