change / sglang / Commits / 2ce32db6

Unverified commit 2ce32db6, authored Nov 03, 2024 by Lianmin Zheng, committed by GitHub on Nov 03, 2024. Parent: 793b79db.

    Let reward model take text inputs instead of message lists (#1907)

    Co-authored-by: Kyle Corbitt <kyle@corbt.com>
Changes: 12 changed files with 43 additions and 58 deletions (+43 −58)
python/sglang/srt/hf_transformers_utils.py   +2  −5
python/sglang/srt/managers/io_struct.py      +2  −2
python/sglang/srt/models/llama.py            +9  −3
python/sglang/srt/models/llama_embedding.py  +2  −10
python/sglang/srt/models/llama_reward.py     +5  −0
python/sglang/srt/models/qwen2_vl.py         +1  −1
python/sglang/srt/server.py                  +3  −19
python/sglang/test/runners.py                +2  −1
test/srt/models/test_embedding_models.py     +4  −5
test/srt/models/test_generation_models.py    +2  −1
test/srt/models/test_reward_models.py        +10 −10
test/srt/run_suite.py                        +1  −1
python/sglang/srt/hf_transformers_utils.py (+2 −5)

```diff
@@ -88,11 +88,8 @@ CONTEXT_LENGTH_KEYS = [
 def get_context_length(config):
-    """Get the context length of a model from a huggingface model configs.
-    And here the config should be text_config part if the model is a multimodal
-    LLM.
-    """
-    text_config = config
+    """Get the context length of a model from a huggingface model configs."""
+    text_config = getattr(config, "text_config", config)
     rope_scaling = getattr(text_config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = rope_scaling.get("factor", 1)
```
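The `getattr` fallback means callers no longer have to unwrap multimodal configs themselves. A minimal sketch of the behavior, assuming two public checkpoints (the model names are illustrative, not part of the commit):

```python
# Sketch: multimodal configs nest the language-model settings under
# `text_config`; plain LLM configs keep them at the top level. The getattr
# fallback handles both without caller-side special-casing.
from transformers import AutoConfig

multimodal = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
plain = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

for config in (multimodal, plain):
    text_config = getattr(config, "text_config", config)
    # CONTEXT_LENGTH_KEYS in sglang probes fields like this one:
    print(type(config).__name__, getattr(text_config, "max_position_embeddings", None))
```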
python/sglang/srt/managers/io_struct.py (+2 −2)

```diff
@@ -238,7 +238,7 @@ class EmbeddingReqInput:
             self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
                 self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
+            self.sampling_params["max_new_tokens"] = 0
         else:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
@@ -248,7 +248,7 @@ class EmbeddingReqInput:
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
+                self.sampling_params[i]["max_new_tokens"] = 0

     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
```
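The surrounding normalization logic (only partially visible in the hunks above) fills in request ids and now forces `max_new_tokens` to zero, since embedding and reward requests decode nothing. A standalone sketch of that behavior with a hypothetical class name, not sglang's actual implementation:

```python
# Toy mirror of the normalization in the hunks: single requests get one rid
# and a zeroed max_new_tokens; batched requests get one of each per item.
import uuid


class EmbeddingReqInputSketch:
    def __init__(self, text, rid=None, sampling_params=None):
        self.text = text
        self.rid = rid
        self.sampling_params = sampling_params
        self.batch_size = 1 if isinstance(text, str) else len(text)

    def normalize(self):
        if self.batch_size == 1:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.sampling_params is None:
                self.sampling_params = {}
            self.sampling_params["max_new_tokens"] = 0
        else:
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            if self.sampling_params is None:
                # a fresh dict per request avoids aliasing one shared dict
                self.sampling_params = [{} for _ in range(self.batch_size)]
            for i in range(self.batch_size):
                self.sampling_params[i]["max_new_tokens"] = 0


req = EmbeddingReqInputSketch(["good answer", "bad answer"])
req.normalize()
print(req.sampling_params)  # [{'max_new_tokens': 0}, {'max_new_tokens': 0}]
```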
python/sglang/srt/models/llama.py (+9 −3)

```diff
@@ -34,6 +34,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
     def forward(
@@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)

     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
```
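The new `Pooler(pooling_type=PoolingType.LAST, normalize=True)` head lets the same `forward` serve generation requests (logits) and embedding/reward requests (pooled hidden states), switched by `get_embedding`. A rough plain-PyTorch rendering of last-token pooling with normalization; sglang's real `Pooler` consumes packed sequences from `forward_batch`, so the batch-major shapes below are a simplification:

```python
import torch
import torch.nn.functional as F

# (batch, seq_len, hidden) stand-in for model hidden states
hidden_states = torch.randn(2, 7, 16)
seq_lens = torch.tensor([5, 7])  # true length of each padded sequence

# PoolingType.LAST: take each sequence's final hidden state...
last = hidden_states[torch.arange(2), seq_lens - 1]
# ...and normalize=True: L2-normalize it to unit length.
embeddings = F.normalize(last, p=2, dim=-1)
print(embeddings.shape)  # torch.Size([2, 16])
```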
python/sglang/srt/models/llama_embedding.py (+2 −10)

```diff
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.pooler(hidden_states, forward_batch)

-    def load_weights(
-        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module):
         ]
         params_dict = dict(self.model.named_parameters())

-        def load_weights_per_param(name, loaded_weight):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

-        if name is None or loaded_weight is None:
-            for name, loaded_weight in weights:
-                load_weights_per_param(name, loaded_weight)
-        else:
-            load_weights_per_param(name, loaded_weight)
-

 class MistralModel(LlamaEmbeddingModel):
     pass
```
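The `stacked_params_mapping` that the restored loop walks over fuses separate q/k/v (and gate/up) checkpoint tensors into single stacked parameters. A toy illustration of the copy-into-slice idea, not sglang's loader (which delegates the slicing to each parameter's `weight_loader`):

```python
# Toy fused-weight loading: checkpoint shards like q_proj/k_proj/v_proj are
# copied into row slices of one qkv_proj parameter, selected by shard id.
import torch

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

hidden = 4
fused = {"layer.qkv_proj.weight": torch.zeros(3 * hidden, hidden)}
offsets = {"q": 0, "k": hidden, "v": 2 * hidden}

checkpoint = {
    "layer.q_proj.weight": torch.full((hidden, hidden), 1.0),
    "layer.k_proj.weight": torch.full((hidden, hidden), 2.0),
    "layer.v_proj.weight": torch.full((hidden, hidden), 3.0),
}

for name, loaded_weight in checkpoint.items():
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            target = name.replace(shard_name, param_name)
            row = offsets[shard_id]
            fused[target][row : row + hidden] = loaded_weight
            break

print(fused["layer.qkv_proj.weight"][:, 0])  # rows of 1s, then 2s, then 3s
```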
python/sglang/srt/models/llama_reward.py (+5 −0)

```diff
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForSequenceClassification is only used for embedding"
+
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         scores = self.score(hidden_states)
```
python/sglang/srt/models/qwen2_vl.py (+1 −1)

```diff
@@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
         for i, image in enumerate(forward_batch.image_inputs):
-            if image == None:
+            if image is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
```
python/sglang/srt/server.py (+3 −19)

```diff
@@ -254,7 +254,7 @@ app.put("/encode")(encode_request)
 async def judge_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request."""
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -696,24 +696,8 @@ class Runtime:
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            # embedding
-            json_data = {
-                "text": prompt,
-            }
-            response = requests.post(
-                self.url + "/encode",
-                json=json_data,
-            )
-        else:
-            # reward
-            json_data = {
-                "conv": prompt,
-            }
-            response = requests.post(
-                self.url + "/judge",
-                json=json_data,
-            )
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

     def __del__(self):
```
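With the conversation-specific branch gone, the client always ships plain text; a conversation must be rendered to a string before it is sent. A hedged client-side sketch (the URL and payload text are placeholders; per this commit, `/judge` accepts the same fields as `/encode`):

```python
import json

import requests

url = "http://localhost:30000"  # assumed local sglang server

# Embedding request: text in, vector out.
emb = requests.post(url + "/encode", json={"text": "The capital of France is Paris."})

# Reward request: pre-rendered conversation text in, score out.
judged = requests.post(url + "/judge", json={"text": "User: Is water wet?\nAssistant: Yes."})

print(json.dumps(emb.json())[:120])
print(json.dumps(judged.json())[:120])
```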
python/sglang/test/runners.py (+2 −1)

```diff
@@ -273,6 +273,7 @@ class SRTRunner:
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
         )
+        self.tokenizer = get_tokenizer(model_path)

     def forward(
         self,
@@ -366,7 +367,7 @@ class SRTRunner:
                 return ModelOutput(embed_logits=logits)
             else:
                 scores = [x["embedding"][0] for x in response]
-                return ModelOutput(scores=logits)
+                return ModelOutput(scores=scores)

     def __enter__(self):
         return self
```
test/srt/models/test_embedding_models.py (+4 −5)

```diff
@@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16]
 class TestEmbeddingModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+
     def assert_close_prefill_logits(
         self,
         prompts,
@@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase):
 if __name__ == "__main__":
-    try:
-        mp.set_start_method("spawn")
-    except RuntimeError:
-        pass
-
     unittest.main()
```
test/srt/models/test_generation_models.py (+2 −1)

```diff
@@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16]
 class TestGenerationModels(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        mp.set_start_method("spawn")
+        mp.set_start_method("spawn", force=True)

     def assert_close_logits_and_output_strs(
         self,
```
test/srt/models/test_reward_models.py (+10 −10)

```diff
@@ -18,10 +18,10 @@ import unittest
 import torch
-from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+from sglang.test.runners import HFRunner, SRTRunner

 MODELS = [
-    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
+    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
 ]
 TORCH_DTYPES = [torch.float16]
@@ -43,6 +43,10 @@ CONVS = [
 class TestRewardModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+
     def assert_close_reward_scores(
         self,
         convs,
@@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase):
             torch_dtype=torch_dtype,
             model_type="reward",
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(convs)
+            prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
+            srt_outputs = srt_runner.forward(prompts)

         hf_scores = torch.tensor(hf_outputs.scores)
         srt_scores = torch.tensor(srt_outputs.scores)
-        print(hf_scores)
-        print(srt_scores)
+        print(f"{hf_scores=}")
+        print(f"{srt_scores=}")

         assert torch.all(
             abs(hf_scores - srt_scores) < tolerance
@@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase):
 if __name__ == "__main__":
-    try:
-        mp.set_start_method("spawn")
-    except RuntimeError:
-        pass
-
     unittest.main()
```
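The key change in the test mirrors what API callers must now do: conversations are rendered to prompt strings with the tokenizer's chat template before reaching the runner. The same step in isolation (the tokenizer name is illustrative; the test uses the reward model's own tokenizer via `srt_runner.tokenizer`):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

convs = [
    [
        {"role": "user", "content": "Is the sky blue?"},
        {"role": "assistant", "content": "Yes, on a clear day."},
    ]
]
# Batched form, as in the diff: one rendered prompt string per conversation.
prompts = tokenizer.apply_chat_template(convs, tokenize=False)
print(prompts[0])
```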
test/srt/run_suite.py (+1 −1)

```diff
@@ -8,7 +8,7 @@ suites = {
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
         "models/test_lora.py",
-        # "models/test_reward_models.py",
+        "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
```