sglang · Commits · 8233cc10

[PD] Support logprob & Add failure test (#6558)

Unverified commit 8233cc10, authored May 23, 2025 by Byron Hsu, committed via GitHub on May 23, 2025.
Parent: 1b2e8f76

Showing 10 changed files with 580 additions and 226 deletions (+580 / -226):

  python/sglang/srt/disaggregation/decode.py                         +58  -44
  python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py    +42   -8
  python/sglang/srt/disaggregation/mini_lb.py                        +57  -24
  python/sglang/srt/disaggregation/prefill.py                        +71  -33
  python/sglang/srt/disaggregation/utils.py                          +84   -1
  python/sglang/srt/managers/schedule_batch.py                        +0   -3
  python/sglang/srt/managers/scheduler.py                             +6  -21
  python/sglang/test/test_utils.py                                    +2  -36
  scripts/playground/disaggregation/cli-logprob.py                   +22   -0
  test/srt/test_disaggregation.py                                   +238  -56
python/sglang/srt/disaggregation/decode.py — View file @ 8233cc10

@@ -36,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -78,8 +79,7 @@ class DecodePreallocQueue:
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         transfer_queue: DecodeTransferQueue,
         tree_cache: BasePrefixCache,
@@ -94,7 +94,6 @@ class DecodePreallocQueue:
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.scheduler = scheduler
@@ -133,15 +132,9 @@ class DecodePreallocQueue:
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
-        kv_args.aux_data_ptrs = [
-            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -211,7 +204,18 @@ class DecodePreallocQueue:
         indices_to_remove = set()
         allocatable_tokens = self._allocatable_tokens()

+        # First, remove all failed requests from the queue
+        for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
         for i, decode_req in enumerate(self.queue):
+            if i in indices_to_remove:
+                continue
+
             if not decode_req.waiting_for_input:
                 continue
@@ -331,7 +335,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: torch.Tensor,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
     ):
@@ -342,11 +346,11 @@ class DecodeTransferQueue:
         self.scheduler = scheduler
         self.tree_cache = tree_cache

-    def add(self, req_conn: DecodeRequest) -> None:
-        self.queue.append(req_conn)
+    def add(self, decode_req: DecodeRequest) -> None:
+        self.queue.append(decode_req)

-    def extend(self, req_conns) -> None:
-        self.queue.extend(req_conns)
+    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+        self.queue.extend(decode_reqs)

     def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
@@ -356,14 +360,6 @@ class DecodeTransferQueue:
             [decode_req.kv_receiver for decode_req in self.queue],
             self.gloo_group,
         )
-        # First, remove all failed requests from the queue
-        for i, decode_req in enumerate(self.queue):
-            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
-                self.scheduler.stream_output(
-                    [decode_req.req], decode_req.req.return_logprob
-                )
-                indices_to_remove.add(i)
-
         transferred_reqs = []
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -387,16 +383,37 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 continue
             elif poll == KVPoll.Success:
-                # pop and push it to waiting queue
                 idx = decode_req.metadata_buffer_index
-                assert len(decode_req.req.output_ids) == 0
-                output_id_buffer = self.metadata_buffers[0]
-                output_id = output_id_buffer[idx][0].item()
-                assert decode_req.req.transferred_output_id is None
-                decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req)
+                (
+                    output_id,
+                    # the last dimension is padded by the same values.
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.metadata_buffers.get_buf(idx)
+                assert len(decode_req.req.output_ids) == 0
+                decode_req.req.output_ids.append(output_id[0].item())
+                if decode_req.req.return_logprob:
+                    decode_req.req.output_token_logprobs_val.append(
+                        output_token_logprobs_val[0].item()
+                    )
+                    decode_req.req.output_token_logprobs_idx.append(
+                        output_token_logprobs_idx[0].item()
+                    )
+                    decode_req.req.output_top_logprobs_val.append(
+                        output_top_logprobs_val[: decode_req.req.top_logprobs_num].tolist()
+                    )
+                    decode_req.req.output_top_logprobs_idx.append(
+                        output_top_logprobs_idx[: decode_req.req.top_logprobs_num].tolist()
+                    )
+                transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
@@ -451,7 +468,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
@@ -497,7 +516,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     batch_, result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
@@ -618,15 +639,8 @@ class SchedulerDisaggregationDecodeMixin:
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-
-        def _num_pre_alloc(req):
-            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
-        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
-        self.waiting_queue.extend([req.req for req in alloc_reqs])
+        self.waiting_queue.extend(alloc_reqs)
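Both DecodePreallocQueue.pop_preallocated and DecodeTransferQueue.pop_transferred now share one drain pattern: scan the queue once to stream out and mark FINISH_ABORT requests, scan again to collect completed entries, then rebuild the queue without the marked indices. A minimal standalone sketch of that pattern, using toy types rather than the sglang classes:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class FakeReq:
        # Illustrative stand-in for DecodeRequest; these field names are not sglang's.
        rid: int
        aborted: bool = False
        transferred: bool = False

    def drain(queue: List[FakeReq]) -> List[FakeReq]:
        """Two-pass drain: drop aborted entries first, then pop completed ones."""
        indices_to_remove = set()

        # Pass 1: mirrors the FINISH_ABORT scan that streams the abort and removes it.
        for i, req in enumerate(queue):
            if req.aborted:
                indices_to_remove.add(i)

        transferred = []
        # Pass 2: skip removed slots, collect entries whose transfer succeeded.
        for i, req in enumerate(queue):
            if i in indices_to_remove:
                continue
            if req.transferred:
                transferred.append(req)
                indices_to_remove.add(i)

        # Compact the queue without the removed entries, as both sglang queues do.
        queue[:] = [r for i, r in enumerate(queue) if i not in indices_to_remove]
        return transferred

    q = [FakeReq(0, aborted=True), FakeReq(1, transferred=True), FakeReq(2)]
    assert [r.rid for r in drain(q)] == [1] and [r.rid for r in q] == [2]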
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py — View file @ 8233cc10

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
 import torch

-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 logger = logging.getLogger(__name__)
@@ -76,6 +76,11 @@ class ScheduleBatchDisaggregationDecodeMixin:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
+
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
+
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
@@ -94,12 +99,41 @@ class ScheduleBatchDisaggregationDecodeMixin:
         """Assign the buffered last input id to schedule batch"""
         self.output_ids = []
         for req in self.reqs:
-            if req.output_ids and len(req.output_ids) > 0:
-                self.output_ids.append(req.output_ids[-1])
-            else:
-                assert req.transferred_output_id is not None
-                req.output_ids.append(req.transferred_output_id)
-                self.output_ids.append(req.transferred_output_id)
+            # resumed retracted req
+            self.output_ids.append(req.output_ids[-1])
             self.tree_cache.cache_unfinished_req(req)
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+        # Simulate the eagle run. We add mock data to hidden states for the
+        # ease of implementation now meaning the first token will have acc rate
+        # of 0.
+        if not self.spec_algorithm.is_none():
+            b = len(self.reqs)
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk, device=self.device
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+            # local import to avoid circular import
+            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+            spec_info = EagleDraftInput(
+                topk_p=topk_p,
+                topk_index=topk_index,
+                hidden_states=torch.ones(
+                    (b, model_config.hidden_size), device=self.device
+                ),
+                verified_id=self.output_ids,
+            )
+            spec_info.prepare_for_extend(self)
+            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+            self.spec_info = spec_info
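The mock EAGLE draft input above builds topk_p as a strictly decreasing ramp scaled by b * speculative_eagle_topk, which keeps every mock probability in (0, 1] and leaves each row sorted best-first. A quick standalone check of that arithmetic (illustrative values, not sglang code):

    import torch

    b, topk = 3, 4  # batch size and speculative_eagle_topk; values are illustrative

    topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32).reshape(b, topk)
    topk_p /= b * topk  # normalize into (0, 1]

    assert topk_p.max() == 1.0 and topk_p.min() > 0
    # Each row is strictly decreasing, so the "best" draft token comes first.
    assert (topk_p[:, :-1] > topk_p[:, 1:]).all()

    topk_index = torch.arange(b * topk).reshape(b, topk)  # distinct mock token ids
    print(topk_p)
    # tensor([[1.0000, 0.9167, 0.8333, 0.7500],
    #         [0.6667, 0.5833, 0.5000, 0.4167],
    #         [0.3333, 0.2500, 0.1667, 0.0833]])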
python/sglang/srt/disaggregation/mini_lb.py — View file @ 8233cc10

@@ -73,11 +73,27 @@ class MiniLoadBalancer:
                 session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                 session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
-            _, decode_response = await asyncio.gather(*tasks)
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+            if "return_logprob" in modified_request:
+                prefill_json = await prefill_response.json()
+                ret_json = await decode_response.json()
+                # merge `meta_info.input_token_logprobs` from prefill to decode
+                if "meta_info" in ret_json:
+                    if "input_token_logprobs" in ret_json["meta_info"]:
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            prefill_json["meta_info"]["input_token_logprobs"]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+            else:
+                ret_json = await decode_response.json()

             return ORJSONResponse(
-                content=await decode_response.json(),
+                content=ret_json,
                 status_code=decode_response.status,
             )
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
                 total=3600
             )  # Add timeout for request reliability
         ) as session:
-            # Create the tasks for both prefill and decode requests
-            tasks = [
-                session.post(f"{prefill_server}/generate", json=modified_request),
-                session.post(f"{decode_server}/generate", json=modified_request),
-            ]
-            # Wait for both responses to complete. Since this is streaming, they return immediately.
-            prefill_response, decode_response = await asyncio.gather(*tasks)
-            async for chunk in decode_response.content:
-                yield chunk
+            try:
+                # Create the tasks for both prefill and decode requests
+                tasks = [
+                    session.post(
+                        f"{prefill_server}/{endpoint}", json=modified_request
+                    ),
+                    session.post(
+                        f"{decode_server}/{endpoint}", json=modified_request
+                    ),
+                ]
+                # Wait for both responses to complete. Since this is streaming, they return immediately.
+                prefill_response, decode_response = await asyncio.gather(*tasks)
+                if modified_request.get("return_logprob", False):
+                    prefill_chunks = []
+                    async for chunk in prefill_response.content:
+                        prefill_chunks.append(chunk)
+                    first_prefill_chunk = (
+                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
+                    )
+                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
+                    async for chunk in decode_response.content:
+                        # Note: This is inefficient
+                        # merge prefill input_token_logprobs, output_token_logprobs to decode
+                        decoded_chunk = chunk.decode("utf-8")
+                        if (
+                            decoded_chunk
+                            and decoded_chunk.startswith("data:")
+                            and "[DONE]" not in decoded_chunk
+                        ):
+                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
+                            ret_json["meta_info"]["input_token_logprobs"] = (
+                                first_prefill_chunk_json["meta_info"]["input_token_logprobs"]
+                                + ret_json["meta_info"]["input_token_logprobs"]
+                            )
+                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
+                        else:
+                            yield chunk
+                else:
+                    async for chunk in decode_response.content:
+                        yield chunk
+            except Exception as e:
+                error_msg = {
+                    "error": {"message": f"Stream processing error: {str(e)}"}
+                }
+                yield b"data: " + orjson.dumps(
+                    error_msg, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            finally:
+                if prefill_response is not None:
+                    await prefill_response.release()

         return StreamingResponse(
             stream_results(),
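For non-streaming requests, the load balancer now prepends the prefill server's meta_info.input_token_logprobs to the decode server's list before returning the decode response. The merge step in isolation, on hypothetical payloads (real responses carry many more meta_info fields):

    # Hypothetical response bodies; entry format is illustrative only.
    prefill_json = {"meta_info": {"input_token_logprobs": [(-1.2, 101), (-0.3, 102)]}}
    ret_json = {
        "meta_info": {
            "input_token_logprobs": [(-0.9, 103)],
            "output_token_logprobs": [(-0.5, 104)],
        }
    }

    # Same guard-and-concatenate logic as MiniLoadBalancer applies.
    if "meta_info" in ret_json and "input_token_logprobs" in ret_json["meta_info"]:
        ret_json["meta_info"]["input_token_logprobs"] = (
            prefill_json["meta_info"]["input_token_logprobs"]
            + ret_json["meta_info"]["input_token_logprobs"]
        )

    assert len(ret_json["meta_info"]["input_token_logprobs"]) == 3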
python/sglang/srt/disaggregation/prefill.py — View file @ 8233cc10

@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
         token_to_kv_pool: KVCache,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_item_lens = kv_item_lens
         # Define req -> input ids buffer
-        kv_args.aux_data_ptrs = [
-            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         Adapted from process_batch_result_prefill
         """
         (
             logits_output,
             next_token_ids,
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
             result.extend_logprob_start_len_per_req,
         )
+        logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )

-        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+
+        # We need to remove the sync in the following function for overlap schedule.
+        self.set_next_batch_sampling_info_done(batch)

     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin:
         )
         # Stream requests which have finished transfer
-        self.stream_output(done_reqs, False, None)
+        self.stream_output(
+            done_reqs,
+            any(req.return_logprob for req in done_reqs),
+            None,
+        )
         self.disagg_prefill_inflight_queue = undone_reqs
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-        token_id: Optional[int] = None,
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:
         """
@@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin:
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
-        # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )
-        last_chunk = token_id is not None

         if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size

+        # Update next start_send_idx
+        req.start_send_idx = end_idx
+
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        req.start_send_idx = end_idx
-        if last_chunk is True:
-            self.disagg_prefill_bootstrap_queue.store_prefill_results(
-                req.metadata_buffer_index, token_id
-            )
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
         req.disagg_kv_sender.send(page_indices)
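send_kv_chunk now takes last_chunk as an explicit parameter instead of inferring it from token_id, and for non-final chunks it still rounds end_idx down to a page boundary so a partially filled page is deferred to the next send. The rounding is plain modular arithmetic; a small sketch with an assumed page size:

    def page_aligned_end(end_idx: int, page_size: int, last_chunk: bool) -> int:
        """Illustrative helper: defer a partial trailing page unless this is the last chunk."""
        if last_chunk:
            return end_idx  # the final chunk may end mid-page
        return end_idx - end_idx % page_size

    # With page_size=16: 70 tokens -> send 64 now; tokens 64..69 wait for the next chunk.
    assert page_aligned_end(70, 16, last_chunk=False) == 64
    assert page_aligned_end(70, 16, last_chunk=True) == 70
    assert page_aligned_end(64, 16, last_chunk=False) == 64  # already aligned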
python/sglang/srt/disaggregation/utils.py — View file @ 8233cc10

@@ -6,7 +6,7 @@ import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import requests
@@ -15,6 +15,9 @@ import torch.distributed as dist
 from sglang.srt.utils import get_ip

+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
 FakeBootstrapHost = "2.2.2.2"

 # env var for testing failure, convert to float explicitly
@@ -196,3 +199,83 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
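MetadataBuffers pads each per-request row so that a single item meets the 64-byte RDMA minimum cited in the comment: 16 int32 or float32 elements is exactly 64 bytes, and the top-logprob rows (128 float32 = 512 bytes) clear it with room to spare. A standalone sketch of the same layout and its per-item sizes (not importing sglang):

    import torch

    size, max_top_logprobs_num = 4, 128
    output_ids = torch.zeros((size, 16), dtype=torch.int32)  # 16 * 4 B = 64 B per row
    top_logprobs_val = torch.zeros((size, max_top_logprobs_num), dtype=torch.float32)

    # item_lens, as get_buf_infos() computes them: bytes per request slot.
    assert output_ids[0].nbytes == 64         # meets the 64-byte RDMA minimum
    assert top_logprobs_val[0].nbytes == 512  # 128 * 4 B

    # A set_buf-style write fills the first element of a padded row;
    # the reader on the decode side takes [0], as pop_transferred does.
    idx = 2
    output_ids[idx][0] = 42
    assert output_ids[idx][0].item() == 42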
python/sglang/srt/managers/schedule_batch.py — View file @ 8233cc10

@@ -607,9 +607,6 @@ class Req:
         self.tmp_end_idx: int = -1
         self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None

     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
python/sglang/srt/managers/scheduler.py — View file @ 8233cc10

@@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import (
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     prepare_abort,
@@ -569,20 +570,13 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
             )
@@ -597,8 +591,7 @@ class Scheduler(
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-            metadata_buffers=metadata_buffers,
-            aux_dtype=aux_dtype,
+            metadata_buffers=self.disagg_metadata_buffers,
             scheduler=self,
             transfer_queue=self.disagg_decode_transfer_queue,
             tree_cache=self.tree_cache,
@@ -618,14 +611,7 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -635,8 +621,7 @@ class Scheduler(
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-            metadata_buffers=metadata_buffers,
-            aux_dtype=aux_dtype,
+            metadata_buffers=self.disagg_metadata_buffers,
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             bootstrap_port=self.server_args.disaggregation_bootstrap_port,
python/sglang/test/test_utils.py — View file @ 8233cc10

@@ -485,7 +485,6 @@ def popen_launch_pd_server(
     api_key: Optional[str] = None,
     other_args: list[str] = (),
     env: Optional[dict] = None,
-    return_stdout_stderr: Optional[tuple] = None,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -515,42 +514,9 @@ def popen_launch_pd_server(
     print(f"command={' '.join(command)}")

-    if return_stdout_stderr:
-        process = subprocess.Popen(
-            command,
-            stdout=return_stdout_stderr[0],
-            stderr=return_stdout_stderr[1],
-            env=env,
-            text=True,
-        )
-    else:
-        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
-
-    start_time = time.perf_counter()
-    with requests.Session() as session:
-        while time.perf_counter() - start_time < timeout:
-            try:
-                headers = {
-                    "Content-Type": "application/json; charset=utf-8",
-                    "Authorization": f"Bearer {api_key}",
-                }
-                response = session.get(
-                    f"{base_url}/health",
-                    headers=headers,
-                )
-                if response.status_code == 200:
-                    return process
-            except requests.RequestException:
-                pass
-
-            return_code = process.poll()
-            if return_code is not None:
-                raise Exception(f"Server unexpectedly exits ({return_code=}).")
-
-            time.sleep(10)
-
-    kill_process_tree(process.pid)
-    raise TimeoutError("Server failed to start within the timeout period.")
+    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+    return process


 def run_with_timeout(
scripts/playground/disaggregation/cli-logprob.py — new file (0 → 100644) — View file @ 8233cc10

prompt = "The capital of taiwan is "

import json

import requests

response = requests.post(
    "http://0.0.0.0:8000/generate",
    json={
        "text": prompt,
        "sampling_params": {"temperature": 0},
        "return_logprob": True,
        "return_input_logprob": True,
        "logprob_start_len": 0,
    },
)

j = response.json()
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
print(len(input_logprobs), len(output_logprobs))
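The playground script only prints list lengths. Mirroring test_logprob in test/srt/test_disaggregation.py, a reasonable extra sanity check — assuming the same meta_info fields — is that output_token_logprobs carries one entry per completion token:

    # Hedged extension of the playground script: same request against the
    # mini_lb endpoint above, plus the consistency check the new test performs.
    import requests

    response = requests.post(
        "http://0.0.0.0:8000/generate",
        json={
            "text": "The capital of taiwan is ",
            "sampling_params": {"temperature": 0},
            "return_logprob": True,
            "return_input_logprob": True,
            "logprob_start_len": 0,
        },
    )
    meta = response.json()["meta_info"]
    assert len(meta["output_token_logprobs"]) == meta["completion_tokens"]
    assert len(meta["input_token_logprobs"]) > 0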
test/srt/test_disaggregation.py — View file @ 8233cc10

+import os
 import subprocess
 import time
 import unittest
 from types import SimpleNamespace
+from urllib.parse import urlparse

 import requests
@@ -25,15 +27,22 @@ class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
         cls.wait_server_ready(cls.prefill_url + "/health")
         cls.wait_server_ready(cls.decode_url + "/health")
@@ -48,7 +57,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]
         print("Starting load balancer:", " ".join(lb_command))
@@ -63,14 +72,10 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
+            "1",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
         ]
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
@@ -85,16 +90,165 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 200),
             "--tp",
-            "4",
+            "1",
             "--base-gpu-id",
-            "4",
+            "1",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
         ]
         cls.process_decode = popen_launch_pd_server(
             cls.model,
             cls.decode_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=decode_args,
         )
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+        # wait for 5 seconds
+        time.sleep(5)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+    def test_logprob(self):
+        prompt = "The capital of taiwan is "
+        response = requests.post(
+            self.lb_url + "/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {"temperature": 0},
+                "return_logprob": True,
+                "return_input_logprob": True,
+                "logprob_start_len": 0,
+            },
+        )
+        j = response.json()
+        completion_tokens = j["meta_info"]["completion_tokens"]
+        input_logprobs = j["meta_info"]["input_token_logprobs"]
+        output_logprobs = j["meta_info"]["output_token_logprobs"]
+        assert (
+            len(output_logprobs) == completion_tokens
+        ), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}"
+        assert (
+            len(input_logprobs) > 0
+        ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
+
+
+class TestDisaggregationMooncakeFailure(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
+        os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = subprocess.Popen(
+            lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
@@ -121,6 +275,8 @@
     @classmethod
     def tearDownClass(cls):
+        # unset DISAGGREGATION_TEST_FAILURE_PROB
+        os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
         for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
             if process:
                 try:
@@ -128,6 +284,9 @@
                 except Exception as e:
                     print(f"Error killing process {process.pid}: {e}")
+        # wait for 5 seconds
+        time.sleep(5)
+
     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
@@ -135,27 +294,29 @@
             num_questions=200,
             max_new_tokens=512,
             parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")
+        # Expect lots of failure but the server cannot crash
         self.assertGreater(metrics["accuracy"], 0.62)


-class TestDisaggregationSpecAccuracy(CustomTestCase):
+class TestDisaggregationMooncakeSpec(CustomTestCase):
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
         cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
         cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
         cls.spec_args = [
             "--speculative-algorithm",
             "EAGLE",
@@ -170,10 +331,13 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
             "--cuda-graph-max-bs",
             "8",
         ]
-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
         cls.wait_server_ready(cls.prefill_url + "/health")
         cls.wait_server_ready(cls.decode_url + "/health")
@@ -188,7 +352,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]
         print("Starting load balancer:", " ".join(lb_command))
@@ -215,21 +379,15 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
     @classmethod
     def start_prefill(cls):
         prefill_args = [
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
+            "2",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "--disaggregation-ib-device",
+            "mlx5_roce0,mlx5_roce1",
         ] + cls.spec_args
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
             cls.prefill_url,
@@ -243,16 +401,12 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 200),
             "--tp",
-            "4",
+            "2",
             "--base-gpu-id",
-            "4",
+            "2",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
+            "--disaggregation-ib-device",
+            "mlx5_roce2,mlx5_roce3",
         ] + cls.spec_args
         cls.process_decode = popen_launch_pd_server(
             cls.model,
@@ -261,15 +415,43 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
             other_args=decode_args,
         )

+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+        # wait for 5 seconds
+        time.sleep(5)
+
     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
             data_path=None,
             num_questions=200,
             max_new_tokens=512,
-            parallel=4,  # TODO: 128 crashes the decode
+            parallel=2,
-            host="http://127.0.0.1",
+            host=f"http://{self.base_host}",
-            port=int(self.lb_url.split(":")[-1]),
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")
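The failure test drives the DISAGGREGATION_TEST_FAILURE_PROB environment variable that disaggregation/utils.py reads ("convert to float explicitly", per the comment kept in that file's diff). The exact injection site is not part of this diff; the sketch below only illustrates the kind of probabilistic fault hook such a variable typically gates:

    import os
    import random

    os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"  # as the test's setUpClass does

    # "convert to float explicitly", per the comment retained in utils.py
    failure_prob = float(os.environ.get("DISAGGREGATION_TEST_FAILURE_PROB", "0.0"))

    def maybe_fail(op_name: str) -> None:
        """Raise with probability failure_prob to exercise the abort/stream_output path."""
        if random.random() < failure_prob:
            raise RuntimeError(f"injected failure in {op_name}")

    try:
        maybe_fail("kv_transfer")  # hypothetical call site
    except RuntimeError as e:
        print(e)  # roughly 5% of calls land here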