Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2ce32db6
Unverified
Commit
2ce32db6
authored
Nov 03, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 03, 2024
Browse files
Let reward model take text inputs instead of message lists (#1907)
Co-authored-by:
Kyle Corbitt
<
kyle@corbt.com
>
parent
793b79db
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
43 additions
and
58 deletions
+43
-58
python/sglang/srt/hf_transformers_utils.py
python/sglang/srt/hf_transformers_utils.py
+2
-5
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+2
-2
python/sglang/srt/models/llama.py
python/sglang/srt/models/llama.py
+9
-3
python/sglang/srt/models/llama_embedding.py
python/sglang/srt/models/llama_embedding.py
+2
-10
python/sglang/srt/models/llama_reward.py
python/sglang/srt/models/llama_reward.py
+5
-0
python/sglang/srt/models/qwen2_vl.py
python/sglang/srt/models/qwen2_vl.py
+1
-1
python/sglang/srt/server.py
python/sglang/srt/server.py
+3
-19
python/sglang/test/runners.py
python/sglang/test/runners.py
+2
-1
test/srt/models/test_embedding_models.py
test/srt/models/test_embedding_models.py
+4
-5
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+2
-1
test/srt/models/test_reward_models.py
test/srt/models/test_reward_models.py
+10
-10
test/srt/run_suite.py
test/srt/run_suite.py
+1
-1
No files found.
python/sglang/srt/hf_transformers_utils.py
View file @
2ce32db6
...
...
@@ -88,11 +88,8 @@ CONTEXT_LENGTH_KEYS = [
def
get_context_length
(
config
):
"""Get the context length of a model from a huggingface model configs.
And here the config should be text_config part if the model is a multimodal
LLM.
"""
text_config
=
getattr
(
config
,
"text_config"
,
config
)
"""Get the context length of a model from a huggingface model configs."""
text_config
=
config
rope_scaling
=
getattr
(
text_config
,
"rope_scaling"
,
None
)
if
rope_scaling
:
rope_scaling_factor
=
rope_scaling
.
get
(
"factor"
,
1
)
...
...
python/sglang/srt/managers/io_struct.py
View file @
2ce32db6
...
...
@@ -238,7 +238,7 @@ class EmbeddingReqInput:
self
.
rid
=
uuid
.
uuid4
().
hex
if
self
.
sampling_params
is
None
:
self
.
sampling_params
=
{}
self
.
sampling_params
[
"max_new_tokens"
]
=
1
self
.
sampling_params
[
"max_new_tokens"
]
=
0
else
:
if
self
.
rid
is
None
:
self
.
rid
=
[
uuid
.
uuid4
().
hex
for
_
in
range
(
self
.
batch_size
)]
...
...
@@ -248,7 +248,7 @@ class EmbeddingReqInput:
if
self
.
sampling_params
is
None
:
self
.
sampling_params
=
[{}]
*
self
.
batch_size
for
i
in
range
(
self
.
batch_size
):
self
.
sampling_params
[
i
][
"max_new_tokens"
]
=
1
self
.
sampling_params
[
i
][
"max_new_tokens"
]
=
0
def
regenerate_rid
(
self
):
self
.
rid
=
uuid
.
uuid4
().
hex
...
...
python/sglang/srt/models/llama.py
View file @
2ce32db6
...
...
@@ -34,6 +34,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
,
LogitsProcessorOutput
from
sglang.srt.layers.pooler
import
Pooler
,
PoolingType
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.torchao_utils
import
apply_torchao_config_
...
...
@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
self
.
model
=
LlamaModel
(
config
,
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
@
torch
.
no_grad
()
def
forward
(
...
...
@@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module):
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
get_embedding
:
bool
=
False
,
)
->
LogitsProcessorOutput
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
,
input_embeds
)
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
forward_batch
)
if
not
get_embedding
:
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
forward_batch
)
else
:
return
self
.
pooler
(
hidden_states
,
forward_batch
)
def
get_hidden_dim
(
self
,
module_name
):
# return input_dim, output_dim
...
...
python/sglang/srt/models/llama_embedding.py
View file @
2ce32db6
...
...
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
,
input_embeds
)
return
self
.
pooler
(
hidden_states
,
forward_batch
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
name
=
None
,
loaded_weight
=
None
):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
...
@@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module):
]
params_dict
=
dict
(
self
.
model
.
named_parameters
())
def
load_weights_per_param
(
name
,
loaded_weight
)
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
or
"projector"
in
name
:
return
if
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
:
...
...
@@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
if
name
is
None
or
loaded_weight
is
None
:
for
name
,
loaded_weight
in
weights
:
load_weights_per_param
(
name
,
loaded_weight
)
else
:
load_weights_per_param
(
name
,
loaded_weight
)
class
MistralModel
(
LlamaEmbeddingModel
):
pass
...
...
python/sglang/srt/models/llama_reward.py
View file @
2ce32db6
...
...
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
input_embeds
:
torch
.
Tensor
=
None
,
get_embedding
:
bool
=
True
,
)
->
EmbeddingPoolerOutput
:
assert
(
get_embedding
),
"LlamaForSequenceClassification is only used for embedding"
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
,
input_embeds
)
scores
=
self
.
score
(
hidden_states
)
...
...
python/sglang/srt/models/qwen2_vl.py
View file @
2ce32db6
...
...
@@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu
=
forward_batch
.
extend_start_loc
.
cpu
().
numpy
()
prefix_lens_cpu
=
forward_batch
.
extend_prefix_lens
.
cpu
().
numpy
()
for
i
,
image
in
enumerate
(
forward_batch
.
image_inputs
):
if
image
==
None
:
if
image
is
None
:
continue
start_idx
=
extend_start_loc_cpu
[
i
]
prefix_len
=
prefix_lens_cpu
[
i
]
...
...
python/sglang/srt/server.py
View file @
2ce32db6
...
...
@@ -254,7 +254,7 @@ app.put("/encode")(encode_request)
async
def
judge_request
(
obj
:
EmbeddingReqInput
,
request
:
Request
):
"""Handle a reward model request."""
"""Handle a reward model request.
Now the arguments and return values are the same as embedding models.
"""
try
:
ret
=
await
tokenizer_manager
.
generate_request
(
obj
,
request
).
__anext__
()
return
ret
...
...
@@ -696,24 +696,8 @@ class Runtime:
self
,
prompt
:
Union
[
str
,
List
[
str
],
List
[
Dict
],
List
[
List
[
Dict
]]],
):
if
isinstance
(
prompt
,
str
)
or
isinstance
(
prompt
[
0
],
str
):
# embedding
json_data
=
{
"text"
:
prompt
,
}
response
=
requests
.
post
(
self
.
url
+
"/encode"
,
json
=
json_data
,
)
else
:
# reward
json_data
=
{
"conv"
:
prompt
,
}
response
=
requests
.
post
(
self
.
url
+
"/judge"
,
json
=
json_data
,
)
json_data
=
{
"text"
:
prompt
}
response
=
requests
.
post
(
self
.
url
+
"/encode"
,
json
=
json_data
)
return
json
.
dumps
(
response
.
json
())
def
__del__
(
self
):
...
...
python/sglang/test/runners.py
View file @
2ce32db6
...
...
@@ -273,6 +273,7 @@ class SRTRunner:
disable_cuda_graph
=
disable_cuda_graph
,
disable_radix_cache
=
disable_radix_cache
,
)
self
.
tokenizer
=
get_tokenizer
(
model_path
)
def
forward
(
self
,
...
...
@@ -366,7 +367,7 @@ class SRTRunner:
return
ModelOutput
(
embed_logits
=
logits
)
else
:
scores
=
[
x
[
"embedding"
][
0
]
for
x
in
response
]
return
ModelOutput
(
scores
=
logits
)
return
ModelOutput
(
scores
=
scores
)
def
__enter__
(
self
):
return
self
...
...
test/srt/models/test_embedding_models.py
View file @
2ce32db6
...
...
@@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16]
class
TestEmbeddingModels
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
def
assert_close_prefill_logits
(
self
,
prompts
,
...
...
@@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase):
if
__name__
==
"__main__"
:
try
:
mp
.
set_start_method
(
"spawn"
)
except
RuntimeError
:
pass
unittest
.
main
()
test/srt/models/test_generation_models.py
View file @
2ce32db6
...
...
@@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16]
class
TestGenerationModels
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
mp
.
set_start_method
(
"spawn"
)
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
def
assert_close_logits_and_output_strs
(
self
,
...
...
test/srt/models/test_reward_models.py
View file @
2ce32db6
...
...
@@ -18,10 +18,10 @@ import unittest
import
torch
from
sglang.test.runners
import
DEFAULT_PROMPTS
,
HFRunner
,
SRTRunner
from
sglang.test.runners
import
HFRunner
,
SRTRunner
MODELS
=
[
(
"LxzGordon/URM-LLaMa-3.1-8B"
,
1
,
2e-2
),
(
"LxzGordon/URM-LLaMa-3.1-8B"
,
1
,
3e-2
),
]
TORCH_DTYPES
=
[
torch
.
float16
]
...
...
@@ -43,6 +43,10 @@ CONVS = [
class
TestRewardModels
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
def
assert_close_reward_scores
(
self
,
convs
,
...
...
@@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase):
torch_dtype
=
torch_dtype
,
model_type
=
"reward"
,
)
as
srt_runner
:
srt_outputs
=
srt_runner
.
forward
(
convs
)
prompts
=
srt_runner
.
tokenizer
.
apply_chat_template
(
convs
,
tokenize
=
False
)
srt_outputs
=
srt_runner
.
forward
(
prompts
)
hf_scores
=
torch
.
tensor
(
hf_outputs
.
scores
)
srt_scores
=
torch
.
tensor
(
srt_outputs
.
scores
)
print
(
hf_scores
)
print
(
srt_scores
)
print
(
f
"
{
hf_scores
=
}
"
)
print
(
f
"
{
srt_scores
=
}
"
)
assert
torch
.
all
(
abs
(
hf_scores
-
srt_scores
)
<
tolerance
...
...
@@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase):
if
__name__
==
"__main__"
:
try
:
mp
.
set_start_method
(
"spawn"
)
except
RuntimeError
:
pass
unittest
.
main
()
test/srt/run_suite.py
View file @
2ce32db6
...
...
@@ -8,7 +8,7 @@ suites = {
"models/test_embedding_models.py"
,
"models/test_generation_models.py"
,
"models/test_lora.py"
,
#
"models/test_reward_models.py",
"models/test_reward_models.py"
,
"sampling/penaltylib"
,
"test_chunked_prefill.py"
,
"test_double_sparsity.py"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment