Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
5652c565
"deploy/vscode:/vscode.git/clone" did not exist on "e2800a45ad63d897facb65d165244dba37876416"
Unverified
Commit
5652c565
authored
Nov 24, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 24, 2024
Browse files
Update CI threshold & Improve code style (#2159)
parent
e3938b2f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
126 additions
and
41 deletions
+126
-41
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+34
-22
python/sglang/bench_one_batch.py
python/sglang/bench_one_batch.py
+1
-0
python/sglang/srt/layers/fused_moe_patch.py
python/sglang/srt/layers/fused_moe_patch.py
+5
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+14
-9
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+6
-3
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+1
-1
test/srt/test_bench_serving.py
test/srt/test_bench_serving.py
+6
-6
test/srt/test_srt_endpoint.py
test/srt/test_srt_endpoint.py
+59
-0
No files found.
.github/workflows/pr-test.yml
View file @
5652c565
...
@@ -50,7 +50,7 @@ jobs:
...
@@ -50,7 +50,7 @@ jobs:
timeout-minutes
:
25
timeout-minutes
:
25
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end
5
python3 run_suite.py --suite minimal --range-begin 0 --range-end
6
unit-test-backend-part-2
:
unit-test-backend-part-2
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -67,7 +67,7 @@ jobs:
...
@@ -67,7 +67,7 @@ jobs:
timeout-minutes
:
25
timeout-minutes
:
25
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin
5
--range-end 14
python3 run_suite.py --suite minimal --range-begin
6
--range-end 14
unit-test-backend-part-3
:
unit-test-backend-part-3
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -103,6 +103,31 @@ jobs:
...
@@ -103,6 +103,31 @@ jobs:
cd test/srt
cd test/srt
python3 run_suite.py --suite minimal --range-begin 21
python3 run_suite.py --suite minimal --range-begin 21
unit-test-backend-2-gpu-part-1
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on
:
2-gpu-runner
steps
:
-
name
:
Checkout code
uses
:
actions/checkout@v3
-
name
:
Install dependencies
run
:
|
bash scripts/ci_install_dependency.sh
-
name
:
Evaluate data parallelism accuracy (DP=2)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_data_parallelism.py
-
name
:
Evaluate MLA accuracy (TP=2)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_mla.py
python3 test_mla_fp8.py
python3 test_dp_attention.py
performance-test-1-gpu-part-1
:
performance-test-1-gpu-part-1
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on
:
1-gpu-runner
runs-on
:
1-gpu-runner
...
@@ -178,23 +203,23 @@ jobs:
...
@@ -178,23 +203,23 @@ jobs:
run
:
|
run
:
|
bash scripts/ci_install_dependency.sh
bash scripts/ci_install_dependency.sh
-
name
:
Benchmark
offline throughput
(TP=2)
-
name
:
Benchmark
single latency
(TP=2)
timeout-minutes
:
10
timeout-minutes
:
10
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 -m unittest test_bench_
serving
.TestBench
Serving
.test_moe_
offline_throughput_
default
python3 -m unittest test_bench_
one_batch
.TestBench
OneBatch
.test_moe_default
-
name
:
Benchmark offline throughput
(w/o RadixAttention)
(TP=2)
-
name
:
Benchmark offline throughput (TP=2)
timeout-minutes
:
10
timeout-minutes
:
10
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_
without_radix_cache
python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_
default
-
name
:
Benchmark
single latency
(TP=2)
-
name
:
Benchmark
offline throughput (w/o RadixAttention)
(TP=2)
timeout-minutes
:
10
timeout-minutes
:
10
run
:
|
run
:
|
cd test/srt
cd test/srt
python3 -m unittest test_bench_
one_batch
.TestBench
OneBatch
.test_moe_
default
python3 -m unittest test_bench_
serving
.TestBench
Serving
.test_moe_
offline_throughput_without_radix_cache
accuracy-test-1-gpu
:
accuracy-test-1-gpu
:
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
if
:
github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...
@@ -238,23 +263,10 @@ jobs:
...
@@ -238,23 +263,10 @@ jobs:
cd test/srt
cd test/srt
python3 test_moe_eval_accuracy_large.py
python3 test_moe_eval_accuracy_large.py
-
name
:
Evaluate MLA accuracy (TP=2)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_mla.py
python3 test_mla_fp8.py
python3 test_dp_attention.py
-
name
:
Evaluate data parallelism accuracy (DP=2)
timeout-minutes
:
10
run
:
|
cd test/srt
python3 test_data_parallelism.py
finish
:
finish
:
needs
:
[
needs
:
[
unit-test-frontend
,
unit-test-backend-part-1
,
unit-test-backend-part-2
,
unit-test-backend-part-3
,
unit-test-backend-part-4
,
unit-test-frontend
,
unit-test-backend-part-1
,
unit-test-backend-part-2
,
unit-test-backend-part-3
,
unit-test-backend-part-4
,
unit-test-backend-2-gpu-part-1
,
performance-test-1-gpu-part-1
,
performance-test-1-gpu-part-2
,
performance-test-2-gpu
,
performance-test-1-gpu-part-1
,
performance-test-1-gpu-part-2
,
performance-test-2-gpu
,
accuracy-test-1-gpu
,
accuracy-test-2-gpu
accuracy-test-1-gpu
,
accuracy-test-2-gpu
]
]
...
...
python/sglang/bench_one_batch.py
View file @
5652c565
...
@@ -212,6 +212,7 @@ def extend(reqs, model_runner):
...
@@ -212,6 +212,7 @@ def extend(reqs, model_runner):
token_to_kv_pool
=
model_runner
.
token_to_kv_pool
,
token_to_kv_pool
=
model_runner
.
token_to_kv_pool
,
tree_cache
=
None
,
tree_cache
=
None
,
model_config
=
model_runner
.
model_config
,
model_config
=
model_runner
.
model_config
,
enable_overlap
=
False
,
)
)
batch
.
prepare_for_extend
()
batch
.
prepare_for_extend
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
...
...
python/sglang/srt/layers/fused_moe
/
patch.py
→
python/sglang/srt/layers/fused_moe
_
patch.py
View file @
5652c565
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
import
torch
import
torch
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
5652c565
...
@@ -437,9 +437,12 @@ class ScheduleBatch:
...
@@ -437,9 +437,12 @@ class ScheduleBatch:
token_to_kv_pool
:
BaseTokenToKVPool
=
None
token_to_kv_pool
:
BaseTokenToKVPool
=
None
tree_cache
:
BasePrefixCache
=
None
tree_cache
:
BasePrefixCache
=
None
#
For utility
#
Batch configs
model_config
:
ModelConfig
=
None
model_config
:
ModelConfig
=
None
forward_mode
:
ForwardMode
=
None
forward_mode
:
ForwardMode
=
None
enable_overlap
:
bool
=
False
# Sampling info
sampling_info
:
SamplingBatchInfo
=
None
sampling_info
:
SamplingBatchInfo
=
None
next_batch_sampling_info
:
SamplingBatchInfo
=
None
next_batch_sampling_info
:
SamplingBatchInfo
=
None
...
@@ -488,10 +491,11 @@ class ScheduleBatch:
...
@@ -488,10 +491,11 @@ class ScheduleBatch:
def
init_new
(
def
init_new
(
cls
,
cls
,
reqs
:
List
[
Req
],
reqs
:
List
[
Req
],
req_to_token_pool
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool
,
token_to_kv_pool
:
ReqToTokenPool
,
tree_cache
,
tree_cache
:
BasePrefixCache
,
model_config
,
model_config
:
ModelConfig
,
enable_overlap
:
bool
,
):
):
return
cls
(
return
cls
(
reqs
=
reqs
,
reqs
=
reqs
,
...
@@ -499,6 +503,7 @@ class ScheduleBatch:
...
@@ -499,6 +503,7 @@ class ScheduleBatch:
token_to_kv_pool
=
token_to_kv_pool
,
token_to_kv_pool
=
token_to_kv_pool
,
tree_cache
=
tree_cache
,
tree_cache
=
tree_cache
,
model_config
=
model_config
,
model_config
=
model_config
,
enable_overlap
=
enable_overlap
,
return_logprob
=
any
(
req
.
return_logprob
for
req
in
reqs
),
return_logprob
=
any
(
req
.
return_logprob
for
req
in
reqs
),
has_stream
=
any
(
req
.
stream
for
req
in
reqs
),
has_stream
=
any
(
req
.
stream
for
req
in
reqs
),
has_grammar
=
any
(
req
.
grammar
for
req
in
reqs
),
has_grammar
=
any
(
req
.
grammar
for
req
in
reqs
),
...
@@ -612,7 +617,7 @@ class ScheduleBatch:
...
@@ -612,7 +617,7 @@ class ScheduleBatch:
assert
len
(
self
.
out_cache_loc
)
==
self
.
extend_num_tokens
assert
len
(
self
.
out_cache_loc
)
==
self
.
extend_num_tokens
def
prepare_for_extend
(
self
,
enable_overlap_schedule
:
bool
=
False
):
def
prepare_for_extend
(
self
):
self
.
forward_mode
=
ForwardMode
.
EXTEND
self
.
forward_mode
=
ForwardMode
.
EXTEND
bs
=
len
(
self
.
reqs
)
bs
=
len
(
self
.
reqs
)
...
@@ -706,7 +711,7 @@ class ScheduleBatch:
...
@@ -706,7 +711,7 @@ class ScheduleBatch:
self
.
sampling_info
=
SamplingBatchInfo
.
from_schedule_batch
(
self
.
sampling_info
=
SamplingBatchInfo
.
from_schedule_batch
(
self
,
self
,
self
.
model_config
.
vocab_size
,
self
.
model_config
.
vocab_size
,
enable_overlap_schedule
=
enable_overlap
_schedule
,
enable_overlap_schedule
=
self
.
enable_overlap
,
)
)
def
mix_with_running
(
self
,
running_batch
:
"ScheduleBatch"
):
def
mix_with_running
(
self
,
running_batch
:
"ScheduleBatch"
):
...
@@ -897,7 +902,7 @@ class ScheduleBatch:
...
@@ -897,7 +902,7 @@ class ScheduleBatch:
self
.
seq_lens_sum
=
0
self
.
seq_lens_sum
=
0
self
.
extend_num_tokens
=
0
self
.
extend_num_tokens
=
0
def
prepare_for_decode
(
self
,
enable_overlap
:
bool
=
False
):
def
prepare_for_decode
(
self
):
self
.
forward_mode
=
ForwardMode
.
DECODE
self
.
forward_mode
=
ForwardMode
.
DECODE
self
.
input_ids
=
self
.
output_ids
self
.
input_ids
=
self
.
output_ids
...
@@ -914,7 +919,7 @@ class ScheduleBatch:
...
@@ -914,7 +919,7 @@ class ScheduleBatch:
else
:
else
:
locs
=
self
.
seq_lens
locs
=
self
.
seq_lens
if
enable_overlap
:
if
self
.
enable_overlap
:
# Do not use in-place operations in the overlap mode
# Do not use in-place operations in the overlap mode
self
.
req_to_token_pool
.
write
(
self
.
req_to_token_pool
.
write
(
(
self
.
req_pool_indices
,
locs
),
self
.
out_cache_loc
(
self
.
req_pool_indices
,
locs
),
self
.
out_cache_loc
...
...
python/sglang/srt/managers/scheduler.py
View file @
5652c565
...
@@ -466,6 +466,7 @@ class Scheduler:
...
@@ -466,6 +466,7 @@ class Scheduler:
self
.
token_to_kv_pool
,
self
.
token_to_kv_pool
,
self
.
tree_cache
,
self
.
tree_cache
,
self
.
model_config
,
self
.
model_config
,
self
.
enable_overlap
,
)
)
idle_batch
.
prepare_for_idle
()
idle_batch
.
prepare_for_idle
()
return
idle_batch
return
idle_batch
...
@@ -842,14 +843,15 @@ class Scheduler:
...
@@ -842,14 +843,15 @@ class Scheduler:
self
.
token_to_kv_pool
,
self
.
token_to_kv_pool
,
self
.
tree_cache
,
self
.
tree_cache
,
self
.
model_config
,
self
.
model_config
,
self
.
enable_overlap
,
)
)
new_batch
.
prepare_for_extend
(
self
.
enable_overlap
)
new_batch
.
prepare_for_extend
()
# Mixed-style chunked prefill
# Mixed-style chunked prefill
if
self
.
is_mixed_chunk
and
self
.
running_batch
is
not
None
:
if
self
.
is_mixed_chunk
and
self
.
running_batch
is
not
None
:
self
.
running_batch
.
filter_batch
()
self
.
running_batch
.
filter_batch
()
if
not
self
.
running_batch
.
is_empty
():
if
not
self
.
running_batch
.
is_empty
():
self
.
running_batch
.
prepare_for_decode
(
self
.
enable_overlap
)
self
.
running_batch
.
prepare_for_decode
()
new_batch
.
mix_with_running
(
self
.
running_batch
)
new_batch
.
mix_with_running
(
self
.
running_batch
)
new_batch
.
decoding_reqs
=
self
.
running_batch
.
reqs
new_batch
.
decoding_reqs
=
self
.
running_batch
.
reqs
self
.
running_batch
=
None
self
.
running_batch
=
None
...
@@ -900,7 +902,7 @@ class Scheduler:
...
@@ -900,7 +902,7 @@ class Scheduler:
self
.
batch_is_full
=
False
self
.
batch_is_full
=
False
# Update batch tensors
# Update batch tensors
batch
.
prepare_for_decode
(
self
.
enable_overlap
)
batch
.
prepare_for_decode
()
return
batch
return
batch
def
run_batch
(
self
,
batch
:
ScheduleBatch
):
def
run_batch
(
self
,
batch
:
ScheduleBatch
):
...
@@ -1055,6 +1057,7 @@ class Scheduler:
...
@@ -1055,6 +1057,7 @@ class Scheduler:
continue
continue
if
self
.
enable_overlap
and
req
.
finished
():
if
self
.
enable_overlap
and
req
.
finished
():
# Free the one delayed token
self
.
token_to_kv_pool
.
free
(
batch
.
out_cache_loc
[
i
:
i
+
1
])
self
.
token_to_kv_pool
.
free
(
batch
.
out_cache_loc
[
i
:
i
+
1
])
continue
continue
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
5652c565
...
@@ -23,7 +23,7 @@ import torch
...
@@ -23,7 +23,7 @@ import torch
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
sglang.srt.layers.fused_moe
.
patch
import
fused_moe_forward_native
from
sglang.srt.layers.fused_moe
_
patch
import
fused_moe_forward_native
from
sglang.srt.layers.logits_processor
import
(
from
sglang.srt.layers.logits_processor
import
(
LogitsMetadata
,
LogitsMetadata
,
LogitsProcessor
,
LogitsProcessor
,
...
...
test/srt/test_bench_serving.py
View file @
5652c565
...
@@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase):
...
@@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase):
)
)
if
is_in_ci
():
if
is_in_ci
():
self
.
assertGreater
(
res
[
"output_throughput"
],
28
50
)
self
.
assertGreater
(
res
[
"output_throughput"
],
33
50
)
def
test_offline_throughput_non_stream_small_batch_size
(
self
):
def
test_offline_throughput_non_stream_small_batch_size
(
self
):
res
=
run_bench_serving
(
res
=
run_bench_serving
(
...
@@ -47,7 +47,7 @@ class TestBenchServing(unittest.TestCase):
...
@@ -47,7 +47,7 @@ class TestBenchServing(unittest.TestCase):
)
)
if
is_in_ci
():
if
is_in_ci
():
self
.
assertGreater
(
res
[
"output_throughput"
],
290
0
)
self
.
assertGreater
(
res
[
"output_throughput"
],
335
0
)
def
test_offline_throughput_without_chunked_prefill
(
self
):
def
test_offline_throughput_without_chunked_prefill
(
self
):
res
=
run_bench_serving
(
res
=
run_bench_serving
(
...
@@ -74,7 +74,7 @@ class TestBenchServing(unittest.TestCase):
...
@@ -74,7 +74,7 @@ class TestBenchServing(unittest.TestCase):
)
)
if
is_in_ci
():
if
is_in_ci
():
self
.
assertGreater
(
res
[
"output_throughput"
],
29
50
)
self
.
assertGreater
(
res
[
"output_throughput"
],
34
50
)
def
test_offline_throughput_default_fp8
(
self
):
def
test_offline_throughput_default_fp8
(
self
):
res
=
run_bench_serving
(
res
=
run_bench_serving
(
...
@@ -85,7 +85,7 @@ class TestBenchServing(unittest.TestCase):
...
@@ -85,7 +85,7 @@ class TestBenchServing(unittest.TestCase):
)
)
if
is_in_ci
():
if
is_in_ci
():
self
.
assertGreater
(
res
[
"output_throughput"
],
3
20
0
)
self
.
assertGreater
(
res
[
"output_throughput"
],
3
85
0
)
def
test_online_latency_default
(
self
):
def
test_online_latency_default
(
self
):
res
=
run_bench_serving
(
res
=
run_bench_serving
(
...
@@ -109,7 +109,7 @@ class TestBenchServing(unittest.TestCase):
...
@@ -109,7 +109,7 @@ class TestBenchServing(unittest.TestCase):
)
)
if
is_in_ci
():
if
is_in_ci
():
self
.
assertGreater
(
res
[
"output_throughput"
],
190
0
)
self
.
assertGreater
(
res
[
"output_throughput"
],
215
0
)
def
test_moe_offline_throughput_without_radix_cache
(
self
):
def
test_moe_offline_throughput_without_radix_cache
(
self
):
res
=
run_bench_serving
(
res
=
run_bench_serving
(
...
@@ -120,7 +120,7 @@ class TestBenchServing(unittest.TestCase):
...
@@ -120,7 +120,7 @@ class TestBenchServing(unittest.TestCase):
)
)
if
is_in_ci
():
if
is_in_ci
():
self
.
assertGreater
(
res
[
"output_throughput"
],
1
9
50
)
self
.
assertGreater
(
res
[
"output_throughput"
],
2
150
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/srt/test_srt_endpoint.py
View file @
5652c565
...
@@ -6,6 +6,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
...
@@ -6,6 +6,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
import
json
import
json
import
unittest
import
unittest
import
numpy
as
np
import
requests
import
requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.srt.utils
import
kill_child_process
...
@@ -132,6 +133,7 @@ class TestSRTEndpoint(unittest.TestCase):
...
@@ -132,6 +133,7 @@ class TestSRTEndpoint(unittest.TestCase):
)
)
def
test_logprob_with_chunked_prefill
(
self
):
def
test_logprob_with_chunked_prefill
(
self
):
"""Test a long prompt that requests output logprobs will not hit OOM."""
new_tokens
=
4
new_tokens
=
4
prompts
=
"I have a very good idea on this. "
*
8000
prompts
=
"I have a very good idea on this. "
*
8000
...
@@ -154,6 +156,63 @@ class TestSRTEndpoint(unittest.TestCase):
...
@@ -154,6 +156,63 @@ class TestSRTEndpoint(unittest.TestCase):
self
.
assertEqual
(
res
[
"meta_info"
][
"completion_tokens"
],
new_tokens
)
self
.
assertEqual
(
res
[
"meta_info"
][
"completion_tokens"
],
new_tokens
)
self
.
assertEqual
(
len
(
res
[
"meta_info"
][
"output_token_logprobs"
]),
new_tokens
)
self
.
assertEqual
(
len
(
res
[
"meta_info"
][
"output_token_logprobs"
]),
new_tokens
)
def
test_logprob_match
(
self
):
"""Test the output logprobs are close to the input logprobs if we run a prefill again."""
def
run_generate
(
prompt
,
return_logprob
=
False
,
max_new_tokens
=
512
,
logprob_start_len
=-
1
):
if
isinstance
(
prompt
,
str
):
prompt_kwargs
=
{
"text"
:
prompt
}
else
:
prompt_kwargs
=
{
"input_ids"
:
prompt
}
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
**
prompt_kwargs
,
"sampling_params"
:
{
"temperature"
:
1.0
,
"max_new_tokens"
:
max_new_tokens
,
"ignore_eos"
:
True
,
},
"return_logprob"
:
return_logprob
,
"return_text_in_logprobs"
:
True
,
"logprob_start_len"
:
logprob_start_len
,
},
)
return
response
.
json
()
prompt
=
"I have a very good idea on how to"
gen
=
run_generate
(
prompt
,
return_logprob
=
True
,
logprob_start_len
=
0
)
output_logprobs
=
np
.
array
(
[
x
[
0
]
for
x
in
gen
[
"meta_info"
][
"output_token_logprobs"
]]
)
num_prompts_tokens
=
gen
[
"meta_info"
][
"prompt_tokens"
]
input_tokens
=
[
x
[
1
]
for
x
in
gen
[
"meta_info"
][
"input_token_logprobs"
]]
output_tokens
=
[
x
[
1
]
for
x
in
gen
[
"meta_info"
][
"output_token_logprobs"
]]
new_prompt
=
input_tokens
+
output_tokens
score
=
run_generate
(
new_prompt
,
return_logprob
=
True
,
logprob_start_len
=
0
,
max_new_tokens
=
0
)
output_logprobs_score
=
np
.
array
(
[
x
[
0
]
for
x
in
score
[
"meta_info"
][
"input_token_logprobs"
][
num_prompts_tokens
:]
]
)
print
(
f
"
{
output_logprobs
[
-
10
:]
=
}
"
)
print
(
f
"
{
output_logprobs_score
[
-
10
:]
=
}
"
)
diff
=
np
.
abs
(
output_logprobs
-
output_logprobs_score
)
max_diff
=
np
.
max
(
diff
)
self
.
assertLess
(
max_diff
,
0.2
)
def
test_get_server_info
(
self
):
def
test_get_server_info
(
self
):
response
=
requests
.
get
(
self
.
base_url
+
"/get_server_info"
)
response
=
requests
.
get
(
self
.
base_url
+
"/get_server_info"
)
response_json
=
response
.
json
()
response_json
=
response
.
json
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment