change / sglang · Commits · 20b765a2
Unverified commit 20b765a2, authored Feb 21, 2025 by simveit; committed by GitHub on Feb 21, 2025.
Model: Support Qwen 72B RM model. (#3772)
parent e3107222
Showing 4 changed files with 75 additions and 1 deletion:
docs/references/supported_models.md        +2 -1
python/sglang/srt/configs/model_config.py  +1 -0
python/sglang/srt/models/qwen2.py          +2 -0
python/sglang/srt/models/qwen2_rm.py       +70 -0
docs/references/supported_models.md
@@ -47,7 +47,8 @@
- `python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Gemma-2-27B-v0.2 --is-embedding`
- InternLM2ForRewardModel
- `python -m sglang.launch_server --model-path internlm/internlm2-7b-reward --is-embedding --trust-remote-code`
- Qwen2ForRewardModel
- `python -m sglang.launch_server --model-path Qwen/Qwen2.5-Math-RM-72B --is-embedding --trust-remote-code --tp-size=4`
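Once one of these servers is up, the reward score can be read back over HTTP. A minimal client sketch under stated assumptions: the Qwen server from the command above is listening on SGLang's default port 30000, and the OpenAI-compatible /v1/embeddings route returns the scalar reward as a one-element embedding (num_labels = 1 in qwen2_rm.py below). The endpoint behavior for reward models is an assumption, not confirmed by this commit.

import requests

# Assumes: server launched with the Qwen2.5-Math-RM-72B command above,
# default port 30000, OpenAI-compatible embeddings route enabled.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "Qwen/Qwen2.5-Math-RM-72B", "input": "1 + 1 = 2"},
)
resp.raise_for_status()
score = resp.json()["data"][0]["embedding"][0]  # single reward score per input
print(score)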
## How to Support a New Language Model

To support a new model in SGLang, you only need to add a single file under [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models).
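A hedged skeleton of what such a file looks like, modeled on the qwen2_rm.py added in this commit; `MyNewModel` and its layers are placeholders for a hypothetical new model, and the module-level `EntryClass` list mirrors how qwen2_rm.py exposes its class to SGLang.

from typing import Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class MyNewModel(nn.Module):  # hypothetical example class
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # ... build the model's layers here ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        ...  # run the model for one batch

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        ...  # map checkpoint tensors onto this module's parameters


# As in qwen2_rm.py below, the class is exported through a module-level list.
EntryClass = [MyNewModel]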
python/sglang/srt/configs/model_config.py
@@ -389,6 +389,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
        or "LlamaForSequenceClassification" in model_architectures
        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
        or "InternLM2ForRewardModel" in model_architectures
        or "Qwen2ForRewardModel" in model_architectures
    ):
        return False
    else:
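In effect, the added branch makes the server classify Qwen2ForRewardModel checkpoints as embedding models rather than generators. A simplified sketch of the check (not the real function body, which has more branches); the architecture names come from the hunk above, with the last entry being the one this commit adds.

from typing import List

_NON_GENERATION_ARCHS = {
    "LlamaForSequenceClassification",
    "LlamaForSequenceClassificationWithNormal_Weights",
    "InternLM2ForRewardModel",
    "Qwen2ForRewardModel",  # added by this commit
}


def is_generation_model_sketch(model_architectures: List[str]) -> bool:
    # Any match against the reward/classification architectures above
    # forces embedding mode.
    return not any(arch in _NON_GENERATION_ARCHS for arch in model_architectures)


assert is_generation_model_sketch(["Qwen2ForRewardModel"]) is False
assert is_generation_model_sketch(["Qwen2ForCausalLM"]) is True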
python/sglang/srt/models/qwen2.py
@@ -379,6 +379,8 @@ class Qwen2ForCausalLM(nn.Module):
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if name.startswith("lm_head"):
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
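The two added lines matter for the new reward model: Qwen2ForRewardModel (next file) reuses Qwen2ForCausalLM.load_weights but defines no lm_head parameter, so lm_head tensors appearing in a checkpoint have to be ignored. A self-contained toy illustrating the failure mode the skip avoids; the module, names, and shapes here are invented for illustration and are not SGLang code.

import torch
from torch import nn


class TinyRewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.score = nn.Linear(8, 1)  # reward head instead of lm_head


checkpoint = {
    "backbone.weight": torch.zeros(8, 8),
    "backbone.bias": torch.zeros(8),
    "lm_head.weight": torch.zeros(32, 8),  # in the file, absent from the module
    "score.weight": torch.zeros(1, 8),
    "score.bias": torch.zeros(1),
}

model = TinyRewardModel()
params = dict(model.named_parameters())  # has no "lm_head.weight" key
for name, tensor in checkpoint.items():
    if name.startswith("lm_head"):
        continue  # mirrors the guard added to Qwen2ForCausalLM.load_weights
    params[name].data.copy_(tensor)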
python/sglang/srt/models/qwen2_rm.py
new file (mode 100644)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config

from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model


class Qwen2ForRewardModel(nn.Module):
    def __init__(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.num_labels = 1
        self.model = Qwen2Model(config, quant_config=quant_config)
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
        self.eos_token_id = config.eos_token_id

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        assert get_embedding, "Qwen2ForRewardModel is only used for embedding"

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        logits = self.score(hidden_states)
        pooled_logits = self.pooler(logits, forward_batch).embeddings

        return EmbeddingPoolerOutput(pooled_logits)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        return Qwen2ForCausalLM.load_weights(self, weights)


EntryClass = [
    Qwen2ForRewardModel,
]
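To see what forward computes per sequence: the score head produces one logit per token, and the LAST pooler keeps only the final token's value, so each prompt yields a single scalar reward. A toy re-implementation of that math with invented shapes (not SGLang's packed-batch layout or Pooler API):

import torch
from torch import nn

hidden_size, seq_len = 16, 5
hidden_states = torch.randn(seq_len, hidden_size)  # one sequence of hidden states

# Same structure as self.score above, with num_labels = 1.
score = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, 1),
)

logits = score(hidden_states)  # (seq_len, 1): a score per token
reward = logits[-1]            # LAST pooling: keep the final token's score
print(reward.shape)            # torch.Size([1])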