sglang · Commits · 70cc0749

Unverified commit 70cc0749, authored Aug 03, 2024 by Ying Sheng; committed by GitHub on Aug 03, 2024.
Add model accuracy test - step 1 (#866)
Parent: 7dd8a7e6

Showing 4 changed files with 330 additions and 3 deletions (+330, -3):

  .github/workflows/unit-test.yml         +2    -0
  python/sglang/srt/server.py             +24   -3
  python/sglang/test/runners.py           +237  -0
  test/srt/models/test_causal_models.py   +67   -0
.github/workflows/unit-test.yml  (view file @ 70cc0749)

@@ -35,6 +35,7 @@ jobs:
          pip install -e "python[all]"
          pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
          pip install --upgrade transformers
+         pip install accelerate

      - name: Test Frontend Language with SRT Backend
        run: |
@@ -50,6 +51,7 @@ jobs:
        run: |
          cd test/srt
          python3 test_eval_accuracy.py
+         python3 models/test_causal_models.py

      - name: Test Frontend Language with OpenAI Backend
        run: |
python/sglang/srt/server.py  (view file @ 70cc0749)

@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -481,10 +481,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )

-    async def add_request(
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,
@@ -507,5 +507,26 @@ class Runtime:
                 yield cur
                 pos += len(cur)

+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
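For orientation, here is a minimal sketch of how the new synchronous Runtime.generate entry point could be exercised once this change lands. It is not part of the commit; the model path and prompt are placeholders, and running it assumes a GPU with the model weights available.

# Sketch only (not part of this commit): exercise Runtime.generate with logprobs enabled.
import json

from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
    dtype="float16",
)
raw = runtime.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 8, "temperature": 0},
    return_logprob=True,
    top_logprobs_num=5,
)
out = json.loads(raw)  # generate() returns the server response as a JSON string
print(out["text"])
# meta_info carries per-token top logprobs for the prefill and output positions
print(len(out["meta_info"]["input_top_logprobs"]))
runtime.shutdown()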
python/sglang/test/runners.py  (new file @ 70cc0749, mode 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import multiprocessing
from dataclasses import dataclass
from typing import List, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.server import Runtime

DEFAULT_PROMPTS = [
    "The capital of France is",
    "The capital of the United Kindom is",
    "Today is a sunny day and I like",
]

NUM_TOP_LOGPROBS = 5


def is_embedding_model(model_path):
    # FIXME incomplete list
    if "e5-mistral-7b-instruct" in model_path.lower():
        return True
    return False


def get_dtype_str(torch_dtype):
    if torch_dtype is torch.float16:
        return "float16"
    else:
        raise NotImplementedError()


@dataclass
class ModelOutput:
    output_strs: str = None
    top_input_logprobs: torch.Tensor = None
    top_output_logprobs: torch.Tensor = None
    embed_logits: torch.Tensor = None


class HFRunner:
    def __init__(
        self,
        model_path,
        torch_dtype=torch.float16,
        is_embedding_model=None,
    ):
        self.in_queue = multiprocessing.Queue()
        self.out_queue = multiprocessing.Queue()

        self.model_proc = multiprocessing.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
                is_embedding_model,
            ),
        )
        self.model_proc.start()

    def start_model_process(
        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        self.is_embedding_model = (
            is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        )
        if not self.is_embedding_model:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_path,
                device="cpu",
            ).to(dtype=torch_dtype)

        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
                if not self.is_embedding_model:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:
                        if isinstance(p, str):
                            input_ids = self.tokenizer.encode(
                                p, return_tensors="pt"
                            ).cuda()
                        else:
                            input_ids = torch.tensor([p], device="cuda")

                        output_ids = self.model.generate(
                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
                        )
                        output_strs.append(self.tokenizer.decode(output_ids[0]))

                        logits = self.model.forward(input_ids).logits[0]
                        logprobs = F.log_softmax(
                            logits, dim=-1, dtype=torch.float32
                        ).tolist()
                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
                        # print("index", index_of_max)
                        logprobs = [
                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
                            for token_logprobs in logprobs
                        ]
                        prefill_logprobs.append(logprobs)

                    out_queue.put(
                        ModelOutput(
                            output_strs=output_strs,
                            top_input_logprobs=prefill_logprobs,
                        )
                    )
                else:
                    # prompts must be a list of strings for the embedding path
                    assert isinstance(prompts, list)
                    logits = self.model.encode(prompts).tolist()
                    out_queue.put(ModelOutput(embed_logits=logits))
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        self.in_queue.put((prompts, max_new_tokens))
        return self.out_queue.get()

    def terminate(self):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.model_proc.terminate()
        self.in_queue = self.out_queue = None


class SRTRunner:
    def __init__(
        self,
        model_path,
        tp_size=1,
        torch_dtype=torch.float16,
        is_embedding_model=None,
    ):
        self.is_embedding_model = (
            is_embedding_model(model_path)
            if is_embedding_model is None
            else is_embedding_model
        )
        if self.is_embedding_model:
            raise NotImplementedError()
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
        )

    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=64,
    ):
        # the return value contains logprobs from prefill
        output_strs = []
        top_input_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for prompt in prompts:
            response = self.runtime.generate(
                prompt,
                sampling_params=sampling_params,
                return_logprob=True,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            output_strs.append(response["text"])
            top_input_logprobs.append(
                [
                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
                    for x in response["meta_info"]["input_top_logprobs"][1:]
                ]
                + [
                    [
                        tup[0]
                        for tup in response["meta_info"]["output_top_logprobs"][0][
                            :NUM_TOP_LOGPROBS
                        ]
                    ]
                ]
            )
            # print(response["meta_info"]["output_top_logprobs"][0])

        return ModelOutput(
            output_strs=output_strs, top_input_logprobs=top_input_logprobs
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
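A quick sketch of how the HuggingFace reference runner above is meant to be driven; this snippet is not part of the commit, the model name is a placeholder, and a CUDA GPU with access to the weights is assumed.

# Sketch only (not part of this commit): drive HFRunner through its queue-backed subprocess.
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner

with HFRunner(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model path
    is_embedding_model=False,
) as hf_runner:
    out = hf_runner.forward(DEFAULT_PROMPTS, max_new_tokens=8)

print(out.output_strs)                  # greedy completions, one per prompt
print(len(out.top_input_logprobs[0]))   # one top-5 logprob list per prefill position

The test file added below uses exactly this pattern to compare HFRunner against SRTRunner.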
test/srt/models/test_causal_models.py  (new file @ 70cc0749, mode 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODELS = [
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
    # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
]
TORCH_DTYPES = [torch.float16]


class TestCausalModels(unittest.TestCase):

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
    ) -> None:
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_embedding_model=False
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_embedding_model=False,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts)

        for i in range(len(prompts)):
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            tolerance = 2e-2
            assert torch.all(
                abs(hf_logprobs - srt_logprobs) < tolerance
            ), f"prefill logprobs not all close"

    def test_prefill_logits(self):
        for model, tp_size in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")
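As an aside, the elementwise tolerance check in the test could also be expressed with torch.testing.assert_close, which reports the offending positions on failure. A sketch under the same 2e-2 absolute tolerance; this variant is not part of the commit and the tensors are made-up stand-ins for the two runners' outputs.

# Sketch only: alternative to `assert torch.all(abs(a - b) < 2e-2)`.
import torch

hf_logprobs = torch.tensor([[-0.010, -2.300], [-0.500, -1.100]])
srt_logprobs = torch.tensor([[-0.012, -2.290], [-0.510, -1.095]])

# rtol=0 makes this a purely absolute comparison, matching the test's intent.
torch.testing.assert_close(srt_logprobs, hf_logprobs, atol=2e-2, rtol=0)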