sglang · Commit 8233cc10 (unverified)

[PD] Support logprob & Add failure test (#6558)

Authored by Byron Hsu on May 23, 2025; committed via GitHub on May 23, 2025. Parent commit: 1b2e8f76
Showing 10 changed files with 580 additions and 226 deletions (+580 −226).
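In short: PD (prefill/decode) disaggregation previously transferred only the first output token id from the prefill instance to the decode instance. This commit widens that metadata path so it also carries the first token's logprob and top-logprobs (via a new `MetadataBuffers` class), teaches the mini load balancer to merge prefill-side input logprobs into decode responses, and adds a failure-injection test. The feature can be exercised end to end with a request like the following; this is a sketch assuming a local PD deployment with the mini load balancer listening on port 8000, mirroring the playground script added below:

```python
# Hedged sketch: query a PD deployment through the mini load balancer and
# read back the merged logprobs. Host and port are placeholders.
import requests

response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "text": "The capital of taiwan is ",
        "sampling_params": {"temperature": 0},
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
meta = response.json()["meta_info"]
# input_token_logprobs now includes the prefill engine's entries, merged in
# by the load balancer; output_token_logprobs comes from the decode engine.
print(len(meta["input_token_logprobs"]), len(meta["output_token_logprobs"]))
```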
Changed files:

- python/sglang/srt/disaggregation/decode.py (+58 −44)
- python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (+42 −8)
- python/sglang/srt/disaggregation/mini_lb.py (+57 −24)
- python/sglang/srt/disaggregation/prefill.py (+71 −33)
- python/sglang/srt/disaggregation/utils.py (+84 −1)
- python/sglang/srt/managers/schedule_batch.py (+0 −3)
- python/sglang/srt/managers/scheduler.py (+6 −21)
- python/sglang/test/test_utils.py (+2 −36)
- scripts/playground/disaggregation/cli-logprob.py (+22 −0, new file)
- test/srt/test_disaggregation.py (+238 −56)
python/sglang/srt/disaggregation/decode.py

```diff
@@ -36,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
```
```diff
@@ -78,8 +79,7 @@ class DecodePreallocQueue:
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         transfer_queue: DecodeTransferQueue,
         tree_cache: BasePrefixCache,
```
```diff
@@ -94,7 +94,6 @@ class DecodePreallocQueue:
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.scheduler = scheduler
```
```diff
@@ -133,15 +132,9 @@ class DecodePreallocQueue:
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
-        kv_args.aux_data_ptrs = [
-            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
```
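Both disaggregation queues now obtain the aux-buffer layout from a single call: `MetadataBuffers.get_buf_infos()` (added in utils.py below) returns matching lists of data pointers, total byte lengths, and per-row byte lengths for all five metadata tensors, replacing the hand-rolled comprehensions over the old single-tensor list.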
```diff
@@ -211,7 +204,18 @@ class DecodePreallocQueue:
         indices_to_remove = set()
         allocatable_tokens = self._allocatable_tokens()

+        # First, remove all failed requests from the queue
+        for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
         for i, decode_req in enumerate(self.queue):
+            if i in indices_to_remove:
+                continue
+
             if not decode_req.waiting_for_input:
                 continue
```
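This abort pass is the counterpart of the block deleted from `DecodeTransferQueue.pop_transferred` further down: failed requests are now streamed out (with their `return_logprob` flag respected) and skipped before any KV preallocation is attempted.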
```diff
@@ -331,7 +335,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: torch.Tensor,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
     ):
```
```diff
@@ -342,11 +346,11 @@ class DecodeTransferQueue:
         self.scheduler = scheduler
         self.tree_cache = tree_cache

-    def add(self, req_conn: DecodeRequest) -> None:
-        self.queue.append(req_conn)
+    def add(self, decode_req: DecodeRequest) -> None:
+        self.queue.append(decode_req)

-    def extend(self, req_conns) -> None:
-        self.queue.extend(req_conns)
+    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+        self.queue.extend(decode_reqs)

     def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
```
```diff
@@ -356,14 +360,6 @@ class DecodeTransferQueue:
             [decode_req.kv_receiver for decode_req in self.queue],
             self.gloo_group,
         )
-        # First, remove all failed requests from the queue
-        for i, decode_req in enumerate(self.queue):
-            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
-                self.scheduler.stream_output(
-                    [decode_req.req], decode_req.req.return_logprob
-                )
-                indices_to_remove.add(i)
-
         transferred_reqs = []
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
```
```diff
@@ -387,16 +383,37 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 continue
             elif poll == KVPoll.Success:
-                # pop and push it to waiting queue
                 idx = decode_req.metadata_buffer_index
-                output_id_buffer = self.metadata_buffers[0]
-                # the last dimension is padded by the same values.
-                output_id = output_id_buffer[idx][0].item()
                 assert len(decode_req.req.output_ids) == 0
-                assert decode_req.req.transferred_output_id is None
-                decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req)
+                (
+                    output_id,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.metadata_buffers.get_buf(idx)
+                decode_req.req.output_ids.append(output_id[0].item())
+                if decode_req.req.return_logprob:
+                    decode_req.req.output_token_logprobs_val.append(
+                        output_token_logprobs_val[0].item()
+                    )
+                    decode_req.req.output_token_logprobs_idx.append(
+                        output_token_logprobs_idx[0].item()
+                    )
+                    decode_req.req.output_top_logprobs_val.append(
+                        output_top_logprobs_val[: decode_req.req.top_logprobs_num].tolist()
+                    )
+                    decode_req.req.output_top_logprobs_idx.append(
+                        output_top_logprobs_idx[: decode_req.req.top_logprobs_num].tolist()
+                    )
+                transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
```
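Note the indexing convention here: the 16-wide metadata rows exist for RDMA padding, so only element `[0]` of the output-id and token-logprob rows is meaningful, and the top-logprob rows are sliced down to the request's own `top_logprobs_num`.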
```diff
@@ -451,7 +468,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
```
```diff
@@ -497,7 +516,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     batch_, result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
```
```diff
@@ -618,15 +639,8 @@ class SchedulerDisaggregationDecodeMixin:
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-
-        def _num_pre_alloc(req):
-            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
-        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
-        self.waiting_queue.extend([req.req for req in alloc_reqs])
+        self.waiting_queue.extend(alloc_reqs)
```
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

```diff
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
 import torch

-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 logger = logging.getLogger(__name__)
```
```diff
@@ -76,6 +76,11 @@ class ScheduleBatchDisaggregationDecodeMixin:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
+
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
+
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
```
```diff
@@ -94,12 +99,41 @@ class ScheduleBatchDisaggregationDecodeMixin:
         """Assign the buffered last input id to schedule batch"""
         self.output_ids = []
         for req in self.reqs:
-            if req.output_ids and len(req.output_ids) > 0:
-                # resumed retracted req
-                self.output_ids.append(req.output_ids[-1])
-            else:
-                assert req.transferred_output_id is not None
-                req.output_ids.append(req.transferred_output_id)
-                self.output_ids.append(req.transferred_output_id)
+            self.output_ids.append(req.output_ids[-1])
             self.tree_cache.cache_unfinished_req(req)
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+        # Simulate the eagle run. We add mock data to hidden states for the
+        # ease of implementation now meaning the first token will have acc rate
+        # of 0.
+        if not self.spec_algorithm.is_none():
+            b = len(self.reqs)
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk, device=self.device
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+            # local import to avoid circular import
+            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+            spec_info = EagleDraftInput(
+                topk_p=topk_p,
+                topk_index=topk_index,
+                hidden_states=torch.ones(
+                    (b, model_config.hidden_size), device=self.device
+                ),
+                verified_id=self.output_ids,
+            )
+            spec_info.prepare_for_extend(self)
+            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+            self.spec_info = spec_info
```
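The mock draft distribution above is easy to sanity-check in isolation; a minimal sketch, with made-up `b` and `topk` standing in for the batch size and `server_args.speculative_eagle_topk`:

```python
import torch

# Mirror the mock top-k construction: values descend and are normalized by
# b * topk, so every probability is positive and the largest is exactly 1.0.
b, topk = 2, 4
topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32).reshape(b, topk)
topk_p /= b * topk
assert topk_p.max().item() == 1.0
assert topk_p.min().item() > 0
```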
python/sglang/srt/disaggregation/mini_lb.py
```diff
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
                 session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                 session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
-            _, decode_response = await asyncio.gather(*tasks)
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            if "return_logprob" in modified_request:
+                prefill_json = await prefill_response.json()
+                ret_json = await decode_response.json()
+
+                # merge `meta_info.input_token_logprobs` from prefill to decode
+                if "meta_info" in ret_json:
+                    if "input_token_logprobs" in ret_json["meta_info"]:
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            prefill_json["meta_info"]["input_token_logprobs"]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+            else:
+                ret_json = await decode_response.json()

             return ORJSONResponse(
-                content=await decode_response.json(),
+                content=ret_json,
                 status_code=decode_response.status,
             )
```
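The merge simply prepends the prefill engine's prompt logprobs to whatever the decode response reports. An illustrative sketch with made-up entries standing in for the logprob records:

```python
# Made-up entries; only the list concatenation matters here.
prefill_meta = {"input_token_logprobs": [[-1.2, 791], [-3.4, 6864]]}
decode_meta = {"input_token_logprobs": [[-0.8, 374]]}

merged = prefill_meta["input_token_logprobs"] + decode_meta["input_token_logprobs"]
assert len(merged) == 3  # prefill's entries come first
```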
```diff
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
                 total=3600
             ),  # Add timeout for request reliability
         ) as session:
             try:
-                # Create the tasks for both prefill and decode requests
-                tasks = [
-                    session.post(f"{prefill_server}/{endpoint}", json=modified_request),
-                    session.post(f"{decode_server}/{endpoint}", json=modified_request),
-                ]
-                # Wait for both responses to complete. Since this is streaming, they return immediately.
-                prefill_response, decode_response = await asyncio.gather(*tasks)
+                # Create the tasks for both prefill and decode requests
+                tasks = [
+                    session.post(f"{prefill_server}/generate", json=modified_request),
+                    session.post(f"{decode_server}/generate", json=modified_request),
+                ]
+                # Wait for both responses to complete. Since this is streaming, they return immediately.
+                prefill_response, decode_response = await asyncio.gather(*tasks)
+
+                if modified_request.get("return_logprob", False):
+                    prefill_chunks = []
+                    async for chunk in prefill_response.content:
+                        prefill_chunks.append(chunk)
+
+                    first_prefill_chunk = (
+                        prefill_chunks[0].decode("utf-8")[5:].strip("\n")
+                    )
+                    first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
+
+                    async for chunk in decode_response.content:
+                        # Note: This is inefficient
+                        # merge prefill input_token_logprobs, output_token_logprobs to decode
+                        decoded_chunk = chunk.decode("utf-8")
+                        if (
+                            decoded_chunk
+                            and decoded_chunk.startswith("data:")
+                            and "[DONE]" not in decoded_chunk
+                        ):
+                            ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
+                            ret_json["meta_info"]["input_token_logprobs"] = (
+                                first_prefill_chunk_json["meta_info"]["input_token_logprobs"]
+                                + ret_json["meta_info"]["input_token_logprobs"]
+                            )
+                            yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
+                        else:
+                            yield chunk
+                else:
+                    async for chunk in decode_response.content:
+                        yield chunk
             except Exception as e:
                 error_msg = {
                     "error": {"message": f"Stream processing error: {str(e)}"}
                 }
                 yield b"data: " + orjson.dumps(
                     error_msg, option=orjson.OPT_NON_STR_KEYS
                 ) + b"\n\n"
             finally:
                 if prefill_response is not None:
                     await prefill_response.release()

         return StreamingResponse(
             stream_results(),
```
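The streaming branch relies on server-sent-event framing, where each event is a `data: <json>` line; the `[5:]` slice drops that prefix before parsing. A small sketch of the same parsing, assuming `orjson` is installed:

```python
import orjson

chunk = b'data: {"meta_info": {"input_token_logprobs": []}}\n\n'
decoded = chunk.decode("utf-8")
if decoded.startswith("data:") and "[DONE]" not in decoded:
    # [5:] strips the "data:" prefix; orjson tolerates the leading space.
    payload = orjson.loads(decoded[5:].strip("\n"))
    assert "meta_info" in payload
```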
python/sglang/srt/disaggregation/prefill.py

```diff
@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
```
```diff
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
         token_to_kv_pool: KVCache,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
```
```diff
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
```
```diff
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_item_lens = kv_item_lens
-        # Define req -> input ids buffer
-        kv_args.aux_data_ptrs = [
-            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
```
```diff
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_infl
-        ight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         Adapted from process_batch_result_prefill
         """
         (
             logits_output,
             next_token_ids,
```
```diff
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
             result.extend_logprob_start_len_per_req,
         )

+        logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()

+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
+
-        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)
                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)
                     req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)

         # We need to remove the sync in the following function for overlap schedule.
         self.set_next_batch_sampling_info_done(batch)
```
```diff
@@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin:
         )

         # Stream requests which have finished transfer
-        self.stream_output(done_reqs, False, None)
+        self.stream_output(
+            done_reqs,
+            any(req.return_logprob for req in done_reqs),
+            None,
+        )

         self.disagg_prefill_inflight_queue = undone_reqs
```
```diff
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-        token_id: Optional[int] = None,
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:
         """
```
```diff
@@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin:
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
         # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
         # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )
-        last_chunk = token_id is not None

         if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size

-        # Update next start_send_idx
-        req.start_send_idx = end_idx
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if last_chunk is True:
-            self.disagg_prefill_bootstrap_queue.store_prefill_results(
-                req.metadata_buffer_index, token_id
-            )
+        req.start_send_idx = end_idx
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
         req.disagg_kv_sender.send(page_indices)
```
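The page-alignment rule is worth spelling out: for every chunk except the last, a trailing partial page is held back so it can go out once it fills (or with the final chunk). A toy check of the arithmetic, with an assumed page size:

```python
page_size = 16  # assumed for illustration
fill_len = 37   # tokens available to send so far

# Non-final chunk: round down to a page boundary...
end_idx = fill_len - fill_len % page_size
assert end_idx == 32  # two full pages go out now
# ...and the remaining 5 tokens wait for a later (or the last) chunk.
```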
python/sglang/srt/disaggregation/utils.py

```diff
@@ -6,7 +6,7 @@ import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import requests
```
```diff
@@ -15,6 +15,9 @@ import torch.distributed as dist
 from sglang.srt.utils import get_ip

+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
 FakeBootstrapHost = "2.2.2.2"

 # env var for testing failure, convert to float explicitly
```
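The comment above refers to `DISAGGREGATION_TEST_FAILURE_PROB`, which the new failure test sets to `0.05`. A hedged sketch of how such a knob gates injected faults; the actual hook sits in the transfer backend and is not shown in this diff:

```python
import os
import random

# Read the probability the way the comment suggests: convert explicitly,
# defaulting to 0 so production paths are unaffected.
failure_prob = float(os.environ.get("DISAGGREGATION_TEST_FAILURE_PROB", "0"))

if random.random() < failure_prob:
    raise RuntimeError("simulated KV transfer failure (test only)")
```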
```diff
@@ -196,3 +199,83 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
```
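A minimal round-trip sketch of the class, using a stub request object with made-up values; in the real system `set_buf` runs on the prefill scheduler and `get_buf` on the decode scheduler after the metadata transfer:

```python
from types import SimpleNamespace

from sglang.srt.disaggregation.utils import MetadataBuffers

bufs = MetadataBuffers(size=4)

# Stub carrying just the fields set_buf touches (values are made up).
req = SimpleNamespace(
    metadata_buffer_index=2,
    output_ids=[42],
    return_logprob=True,
    output_token_logprobs_val=[-0.125],
    output_token_logprobs_idx=[42],
    output_top_logprobs_val=[[-0.125, -2.5]],
    output_top_logprobs_idx=[[42, 7]],
)
bufs.set_buf(req)

output_id, val, idx, top_val, top_idx = bufs.get_buf(2)
assert output_id[0].item() == 42
assert abs(val[0].item() - (-0.125)) < 1e-6
assert top_idx[:2].tolist() == [42, 7]
```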
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -607,9 +607,6 @@ class Req:
         self.tmp_end_idx: int = -1
         self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None

     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import (
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     prepare_abort,
```
```diff
@@ -569,20 +570,13 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
             )
```
```diff
@@ -597,8 +591,7 @@ class Scheduler(
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-            metadata_buffers=metadata_buffers,
-            aux_dtype=aux_dtype,
+            metadata_buffers=self.disagg_metadata_buffers,
             scheduler=self,
             transfer_queue=self.disagg_decode_transfer_queue,
             tree_cache=self.tree_cache,
```
```diff
@@ -618,14 +611,7 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
```
```diff
@@ -635,8 +621,7 @@ class Scheduler(
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-            metadata_buffers=metadata_buffers,
-            aux_dtype=aux_dtype,
+            metadata_buffers=self.disagg_metadata_buffers,
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             bootstrap_port=self.server_args.disaggregation_bootstrap_port,
```
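The deleted comments explain the layout that `MetadataBuffers` now owns: each per-request row must satisfy the transport's 64-byte minimum, which is why the narrow buffers are padded to 16 elements of 4-byte dtypes. A quick check:

```python
import torch

row_int32 = torch.zeros(16, dtype=torch.int32)
row_float32 = torch.zeros(16, dtype=torch.float32)
# 16 * 4 bytes = 64 bytes per row, meeting the stated RDMA minimum.
assert row_int32.nbytes == 64 and row_float32.nbytes == 64
```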
python/sglang/test/test_utils.py

```diff
@@ -485,7 +485,6 @@ def popen_launch_pd_server(
     api_key: Optional[str] = None,
     other_args: list[str] = (),
     env: Optional[dict] = None,
-    return_stdout_stderr: Optional[tuple] = None,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
```
```diff
@@ -515,42 +514,9 @@ def popen_launch_pd_server(
     print(f"command={' '.join(command)}")

-    if return_stdout_stderr:
-        process = subprocess.Popen(
-            command,
-            stdout=return_stdout_stderr[0],
-            stderr=return_stdout_stderr[1],
-            env=env,
-            text=True,
-        )
-    else:
-        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
-
-    start_time = time.perf_counter()
-    with requests.Session() as session:
-        while time.perf_counter() - start_time < timeout:
-            try:
-                headers = {
-                    "Content-Type": "application/json; charset=utf-8",
-                    "Authorization": f"Bearer {api_key}",
-                }
-                response = session.get(f"{base_url}/health", headers=headers)
-                if response.status_code == 200:
-                    return process
-            except requests.RequestException:
-                pass
-
-            return_code = process.poll()
-            if return_code is not None:
-                raise Exception(f"Server unexpectedly exits ({return_code=}).")
-
-            time.sleep(10)
-
-    kill_process_tree(process.pid)
-    raise TimeoutError("Server failed to start within the timeout period.")
+    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+    return process


 def run_with_timeout(
```
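With this change `popen_launch_pd_server` no longer blocks polling `/health` (and drops the unused `return_stdout_stderr` plumbing); readiness checking moves into the tests' new `wait_server_ready` helper, which lets the prefill and decode servers boot concurrently before the load balancer starts.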
scripts/playground/disaggregation/cli-logprob.py (new file, mode 100644)

```python
prompt = "The capital of taiwan is "

import json

import requests

response = requests.post(
    "http://0.0.0.0:8000/generate",
    json={
        "text": prompt,
        "sampling_params": {"temperature": 0},
        "return_logprob": True,
        "return_input_logprob": True,
        "logprob_start_len": 0,
    },
)
j = response.json()
input_logprobs = j["meta_info"]["input_token_logprobs"]
output_logprobs = j["meta_info"]["output_token_logprobs"]
print(len(input_logprobs), len(output_logprobs))
```
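This playground script assumes a PD deployment already serving behind `0.0.0.0:8000` (for example, the mini load balancer from this PR); it is the manual counterpart of the `test_logprob` case added below.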
test/srt/test_disaggregation.py

```diff
+import os
 import subprocess
 import time
 import unittest
 from types import SimpleNamespace
+from urllib.parse import urlparse

 import requests
```
```diff
@@ -25,15 +27,22 @@ class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
```
```diff
@@ -48,7 +57,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]
         print("Starting load balancer:", " ".join(lb_command))
```
```diff
@@ -63,14 +72,10 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
         ]
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
```
```diff
@@ -85,16 +90,165 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
             "--tp",
             "1",
             "--base-gpu-id",
             "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
         ]
         cls.process_decode = popen_launch_pd_server(
             cls.model,
             cls.decode_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=decode_args,
         )

+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+        # wait for 5 seconds
+        time.sleep(5)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+    def test_logprob(self):
+        prompt = "The capital of taiwan is "
+        response = requests.post(
+            self.lb_url + "/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {"temperature": 0},
+                "return_logprob": True,
+                "return_input_logprob": True,
+                "logprob_start_len": 0,
+            },
+        )
+        j = response.json()
+        completion_tokens = j["meta_info"]["completion_tokens"]
+        input_logprobs = j["meta_info"]["input_token_logprobs"]
+        output_logprobs = j["meta_info"]["output_token_logprobs"]
+        assert (
+            len(output_logprobs) == completion_tokens
+        ), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}"
+        assert (
+            len(input_logprobs) > 0
+        ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
+
+
+class TestDisaggregationMooncakeFailure(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
+        os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = subprocess.Popen(
+            lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
```
```diff
@@ -121,6 +275,8 @@ class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def tearDownClass(cls):
+        # unset DISAGGREGATION_TEST_FAILURE_PROB
+        os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
         for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
             if process:
                 try:
```
```diff
@@ -128,6 +284,9 @@ class TestDisaggregationAccuracy(CustomTestCase):
                 except Exception as e:
                     print(f"Error killing process {process.pid}: {e}")
         # wait for 5 seconds
         time.sleep(5)

     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
```
```diff
@@ -135,27 +294,29 @@ class TestDisaggregationAccuracy(CustomTestCase):
             num_questions=200,
             max_new_tokens=512,
             parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")
         self.assertGreater(metrics["accuracy"], 0.62)
+        # Expect lots of failure but the server cannot crash


-class TestDisaggregationSpecAccuracy(CustomTestCase):
+class TestDisaggregationMooncakeSpec(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
         cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
         cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"

         cls.spec_args = [
             "--speculative-algorithm",
             "EAGLE",
```
```diff
@@ -170,10 +331,13 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--cuda-graph-max-bs",
             "8",
         ]
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
```
```diff
@@ -188,7 +352,7 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]
         print("Starting load balancer:", " ".join(lb_command))
```
```diff
@@ -215,21 +379,15 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
     @classmethod
     def start_prefill(cls):
         prefill_args = [
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "2",
+            "--disaggregation-ib-device",
+            "mlx5_roce0,mlx5_roce1",
         ] + cls.spec_args
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
             cls.prefill_url,
```
```diff
@@ -243,16 +401,12 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 200),
             "--tp",
-            "4",
+            "2",
             "--base-gpu-id",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
+            "2",
+            "--disaggregation-ib-device",
+            "mlx5_roce2,mlx5_roce3",
         ] + cls.spec_args
         cls.process_decode = popen_launch_pd_server(
             cls.model,
```
```diff
@@ -261,15 +415,43 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             other_args=decode_args,
         )

+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
     @classmethod
     def tearDownClass(cls):
         for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
             if process:
                 try:
                     kill_process_tree(process.pid)
                 except Exception as e:
                     print(f"Error killing process {process.pid}: {e}")
         # wait for 5 seconds
         time.sleep(5)

     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
             data_path=None,
             num_questions=200,
             max_new_tokens=512,
-            parallel=4,  # TODO: 128 crashes the decode
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            parallel=2,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")
```