Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f70f7258
Unverified
Commit
f70f7258
authored
Jun 08, 2024
by
Qubitium
Committed by
GitHub
Jun 07, 2024
Browse files
Fix rid state map leak + Refractor .finished (#505)
Co-authored-by:
ZX
<
zx@lbx.dev
>
parent
c0ae70c8
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
134 additions
and
112 deletions
+134
-112
python/sglang/srt/managers/controller/dp_worker.py
python/sglang/srt/managers/controller/dp_worker.py
+4
-1
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+52
-31
python/sglang/srt/managers/controller/manager_single.py
python/sglang/srt/managers/controller/manager_single.py
+1
-1
python/sglang/srt/managers/controller/tp_worker.py
python/sglang/srt/managers/controller/tp_worker.py
+8
-13
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+40
-41
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+5
-6
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+24
-19
No files found.
python/sglang/srt/managers/controller/dp_worker.py
View file @
f70f7258
...
...
@@ -10,6 +10,7 @@ import zmq
from
sglang.global_config
import
global_config
from
sglang.srt.managers.controller.tp_worker
import
ModelTpClient
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.utils
import
get_exception_traceback
...
...
@@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread):
requests
=
[]
while
not
self
.
request_queue
.
empty
():
requests
.
append
(
self
.
request_queue
.
get
())
out_pyobjs
:
List
[
BatchTokenIDOut
]
=
[]
try
:
out_pyobjs
=
await
self
.
step
(
requests
)
except
Exception
:
...
...
@@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread):
# async sleep for receiving the subsequent request and avoiding cache miss
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
has_finished
=
any
([
obj
.
finished
_reason
is
not
None
for
obj
in
out_pyobjs
])
if
has_finished
:
await
asyncio
.
sleep
(
self
.
request_dependency_delay
)
await
asyncio
.
sleep
(
global_config
.
wait_for_new_request_delay
)
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
f70f7258
...
...
@@ -15,25 +15,47 @@ class ForwardMode(IntEnum):
EXTEND
=
auto
()
DECODE
=
auto
()
class
BaseFinishReason
:
def
__init__
(
self
,
is_error
:
bool
=
False
):
self
.
is_error
=
is_error
class
FinishReason
(
IntEnum
):
EOS_TOKEN
=
auto
()
LENGTH
=
auto
()
STOP_STR
=
auto
()
ABORT
=
auto
()
@
staticmethod
def
to_str
(
reason
):
if
reason
==
FinishReason
.
EOS_TOKEN
:
return
None
elif
reason
==
FinishReason
.
LENGTH
:
return
"length"
elif
reason
==
FinishReason
.
STOP_STR
:
return
"stop"
elif
reason
==
FinishReason
.
ABORT
:
return
"abort"
else
:
return
None
def
__str__
(
self
):
raise
NotImplementedError
(
"Subclasses must implement this method"
)
class
FINISH_MATCHED_TOKEN
(
BaseFinishReason
):
def
__init__
(
self
,
matched
:
int
|
List
[
int
]):
super
().
__init__
()
self
.
matched
=
matched
def
__str__
(
self
)
->
str
:
return
f
"FINISH_MATCHED_TOKEN:
{
self
.
matched
}
"
class
FINISH_LENGTH
(
BaseFinishReason
):
def
__init__
(
self
,
length
:
int
):
super
().
__init__
()
self
.
length
=
length
def
__str__
(
self
)
->
str
:
return
f
"FINISH_LENGTH:
{
self
.
length
}
"
class
FINISH_MATCHED_STR
(
BaseFinishReason
):
def
__init__
(
self
,
matched
:
str
):
super
().
__init__
()
self
.
matched
=
matched
def
__str__
(
self
)
->
str
:
return
f
"FINISH_MATCHED_STR:
{
self
.
matched
}
"
class
FINISH_ABORT
(
BaseFinishReason
):
def
__init__
(
self
):
super
().
__init__
(
is_error
=
True
)
def
__str__
(
self
)
->
str
:
return
"FINISH_ABORT"
class
Req
:
...
...
@@ -61,11 +83,10 @@ class Req:
self
.
sampling_params
=
None
self
.
stream
=
False
# Check finish
self
.
tokenizer
=
None
self
.
finished
=
False
self
.
finish_reason
=
None
self
.
hit_stop_str
=
None
# Check finish
self
.
finished_reason
=
None
# Prefix info
self
.
extend_input_len
=
0
...
...
@@ -90,6 +111,10 @@ class Req:
self
.
regex_fsm_state
=
0
self
.
jump_forward_map
=
None
# whether request reached finished condition
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
def
partial_decode
(
self
,
ids
):
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
ids
[
0
])
first_token
=
(
...
...
@@ -101,23 +126,21 @@ class Req:
return
self
.
sampling_params
.
max_new_tokens
def
check_finished
(
self
):
if
self
.
finished
:
if
self
.
finished
()
:
return
if
(
len
(
self
.
prev_output_ids
)
+
len
(
self
.
output_ids
)
>=
self
.
sampling_params
.
max_new_tokens
):
self
.
finished
=
True
self
.
finish_reason
=
FinishReason
.
LENGTH
self
.
finished_reason
=
FINISH_LENGTH
(
len
(
self
.
prev_output_ids
)
+
len
(
self
.
output_ids
))
return
if
(
self
.
output_ids
[
-
1
]
==
self
.
tokenizer
.
eos_token_id
and
self
.
sampling_params
.
ignore_eos
==
False
and
not
self
.
sampling_params
.
ignore_eos
):
self
.
finished
=
True
self
.
finish_reason
=
FinishReason
.
EOS_TOKEN
self
.
finished_reason
=
FINISH_MATCHED_TOKEN
(
matched
=
self
.
tokenizer
.
eos_token_id
)
return
if
len
(
self
.
sampling_params
.
stop_strs
)
>
0
:
...
...
@@ -128,9 +151,7 @@ class Req:
for
stop_str
in
self
.
sampling_params
.
stop_strs
:
# FIXME: (minor) try incremental match in prev_output_str
if
stop_str
in
tail_str
or
stop_str
in
self
.
prev_output_str
:
self
.
finished
=
True
self
.
finish_reason
=
FinishReason
.
STOP_STR
self
.
hit_stop_str
=
stop_str
self
.
finished_reason
=
FINISH_MATCHED_STR
(
matched
=
stop_str
)
return
def
jump_forward_and_retokenize
(
self
,
jump_forward_str
,
next_state
):
...
...
python/sglang/srt/managers/controller/manager_single.py
View file @
f70f7258
...
...
@@ -45,7 +45,7 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss
slept
=
False
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
has_finished
=
any
([
obj
.
finished
_reason
is
not
None
for
obj
in
out_pyobjs
])
if
has_finished
:
if
self
.
request_dependency_delay
>
0
:
slept
=
True
...
...
python/sglang/srt/managers/controller/tp_worker.py
View file @
f70f7258
...
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq
,
TokenizedGenerateReqInput
,
)
from
sglang.srt.managers.controller.infer_batch
import
Ba
tch
,
FinishReason
,
ForwardMode
,
Req
from
sglang.srt.managers.controller.infer_batch
import
Ba
se
FinishReason
,
Batch
,
FINISH_ABORT
,
ForwardMode
,
Req
from
sglang.srt.managers.controller.model_runner
import
ModelRunner
from
sglang.srt.managers.controller.radix_cache
import
RadixCache
from
sglang.srt.managers.controller.schedule_heuristic
import
ScheduleHeuristic
...
...
@@ -595,20 +595,19 @@ class ModelTpServer:
output_rids
=
[]
prev_output_strs
=
[]
output_tokens
=
[]
output_hit_stop_str
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_meta_info
=
[]
output_finished
=
[]
output_finished
_reason
:
List
[
BaseFinishReason
]
=
[]
finished_indices
=
[]
unfinished_indices
=
[]
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
.
finished
:
if
req
.
finished
()
:
finished_indices
.
append
(
i
)
else
:
unfinished_indices
.
append
(
i
)
if
req
.
finished
or
(
if
req
.
finished
()
or
(
(
req
.
stream
and
(
...
...
@@ -620,7 +619,6 @@ class ModelTpServer:
output_rids
.
append
(
req
.
rid
)
prev_output_strs
.
append
(
req
.
prev_output_str
)
output_tokens
.
append
(
req
.
output_ids
)
output_hit_stop_str
.
append
(
req
.
hit_stop_str
)
output_skip_special_tokens
.
append
(
req
.
sampling_params
.
skip_special_tokens
)
...
...
@@ -632,8 +630,7 @@ class ModelTpServer:
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"completion_tokens"
:
len
(
req
.
prev_output_ids
)
+
len
(
req
.
output_ids
),
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"finish_reason"
:
FinishReason
.
to_str
(
req
.
finish_reason
),
"hit_stop_str"
:
req
.
hit_stop_str
,
"finish_reason"
:
str
(
req
.
finished_reason
),
}
if
req
.
return_logprob
:
(
...
...
@@ -650,7 +647,7 @@ class ModelTpServer:
req
.
normalized_prompt_logprob
,
)
output_meta_info
.
append
(
meta_info
)
output_finished
.
append
(
req
.
finished
)
output_finished
_reason
.
append
(
req
.
finished
_reason
)
# Send to detokenizer
if
output_rids
:
...
...
@@ -659,11 +656,10 @@ class ModelTpServer:
output_rids
,
prev_output_strs
,
output_tokens
,
output_hit_stop_str
,
output_skip_special_tokens
,
output_spaces_between_special_tokens
,
output_meta_info
,
output_finished
,
output_finished
_reason
,
)
)
...
...
@@ -720,8 +716,7 @@ class ModelTpServer:
if
self
.
running_batch
:
for
req
in
self
.
running_batch
.
reqs
:
if
req
.
rid
==
recv_req
.
rid
:
req
.
finished
=
True
req
.
finish_reason
=
FinishReason
.
ABORT
req
.
finished_reason
=
FINISH_ABORT
()
break
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
f70f7258
...
...
@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
from
sglang.srt.managers.io_struct
import
BatchStrOut
,
BatchTokenIDOut
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.utils
import
get_exception_traceback
,
graceful_registry
from
sglang.srt.managers.controller.infer_batch
import
FINISH_MATCHED_STR
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
...
@@ -34,49 +35,47 @@ class DetokenizerManager:
async
def
handle_loop
(
self
):
while
True
:
recv_obj
=
await
self
.
recv_from_router
.
recv_pyobj
()
if
isinstance
(
recv_obj
,
BatchTokenIDOut
):
output_tokens
=
recv_obj
.
output_tokens
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs
=
self
.
tokenizer
.
batch_decode
(
output_tokens
,
skip_special_tokens
=
recv_obj
.
skip_special_tokens
[
0
],
spaces_between_special_tokens
=
recv_obj
.
spaces_between_special_tokens
[
0
],
)
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for
i
in
range
(
len
(
output_strs
)):
if
len
(
output_tokens
[
i
])
>
0
:
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
int
(
output_tokens
[
i
][
0
])
)
if
not
isinstance
(
first_token
,
str
):
first_token
=
first_token
.
decode
(
"utf-8"
,
errors
=
"ignore"
)
if
first_token
.
startswith
(
"▁"
):
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
recv_obj
.
prev_output_strs
[
i
]
+
output_strs
[
i
]
if
recv_obj
.
hit_stop_str
[
i
]
is
not
None
:
pos
=
output_strs
[
i
].
find
(
recv_obj
.
hit_stop_str
[
i
])
if
pos
!=
-
1
:
output_strs
[
i
]
=
output_strs
[
i
][:
pos
]
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
recv_obj
.
rids
,
output_strs
,
recv_obj
.
meta_info
,
recv_obj
.
finished
,
recv_obj
:
BatchTokenIDOut
=
await
self
.
recv_from_router
.
recv_pyobj
()
assert
isinstance
(
recv_obj
,
BatchTokenIDOut
)
output_tokens
=
recv_obj
.
output_tokens
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs
=
self
.
tokenizer
.
batch_decode
(
output_tokens
,
skip_special_tokens
=
recv_obj
.
skip_special_tokens
[
0
],
spaces_between_special_tokens
=
recv_obj
.
spaces_between_special_tokens
[
0
],
)
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for
i
in
range
(
len
(
output_strs
)):
if
len
(
output_tokens
[
i
])
>
0
:
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
int
(
output_tokens
[
i
][
0
])
)
if
not
isinstance
(
first_token
,
str
):
first_token
=
first_token
.
decode
(
"utf-8"
,
errors
=
"ignore"
)
if
first_token
.
startswith
(
"▁"
):
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
recv_obj
.
prev_output_strs
[
i
]
+
output_strs
[
i
]
if
isinstance
(
recv_obj
.
finished_reason
[
i
],
FINISH_MATCHED_STR
):
pos
=
output_strs
[
i
].
find
(
recv_obj
.
finished_reason
[
i
].
matched
)
if
pos
!=
-
1
:
output_strs
[
i
]
=
output_strs
[
i
][:
pos
]
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
rids
=
recv_obj
.
rids
,
output_str
=
output_strs
,
meta_info
=
recv_obj
.
meta_info
,
finished_reason
=
recv_obj
.
finished_reason
,
)
else
:
raise
ValueError
(
f
"Invalid object:
{
recv_obj
}
"
)
)
def
start_detokenizer_process
(
...
...
python/sglang/srt/managers/io_struct.py
View file @
f70f7258
...
...
@@ -3,6 +3,7 @@ from dataclasses import dataclass
from
typing
import
Dict
,
List
,
Optional
,
Union
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.managers.controller.infer_batch
import
BaseFinishReason
@
dataclass
...
...
@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput:
@
dataclass
class
BatchTokenIDOut
:
rids
:
List
[
str
]
prev_output_strs
:
List
[
str
]
prev_output_strs
:
List
[
str
]
output_tokens
:
List
[
List
[
int
]]
hit_stop_str
:
List
[
Optional
[
str
]]
skip_special_tokens
:
List
[
bool
]
spaces_between_special_tokens
:
List
[
bool
]
meta_info
:
List
[
Dict
]
finished
:
List
[
bool
]
finished_reason
:
List
[
BaseFinishReason
]
@
dataclass
class
BatchStrOut
:
rids
:
List
[
str
]
output_str
:
List
[
str
]
meta_info
:
List
[
Dict
]
finished
:
List
[
bool
]
finished
_reason
:
List
[
BaseFinishReason
]
@
dataclass
...
...
@@ -134,4 +133,4 @@ class AbortReq:
@
dataclass
class
DetokenizeReqInput
:
input_ids
:
List
[
int
]
\ No newline at end of file
input_ids
:
List
[
int
]
python/sglang/srt/managers/tokenizer_manager.py
View file @
f70f7258
...
...
@@ -4,7 +4,7 @@ import dataclasses
import
logging
import
multiprocessing
as
mp
import
os
from
typing
import
List
from
typing
import
List
,
Dict
import
numpy
as
np
import
transformers
...
...
@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput
,
TokenizedGenerateReqInput
,
)
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
from
sglang.srt.mm_utils
import
expand2square
,
process_anyres_image
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
...
@@ -89,7 +90,7 @@ class TokenizerManager:
)
self
.
to_create_loop
=
True
self
.
rid_to_state
=
{}
#
Dict[str
->
ReqState]
self
.
rid_to_state
:
Dict
[
str
,
ReqState
]
=
{}
async
def
get_pixel_values
(
self
,
image_data
):
aspect_ratio
=
getattr
(
self
.
hf_config
,
"image_aspect_ratio"
,
None
)
...
...
@@ -183,12 +184,17 @@ class TokenizerManager:
if
self
.
server_args
.
log_requests
and
state
.
finished
:
logger
.
info
(
f
"in=
{
obj
.
text
}
, out=
{
out
}
"
)
yield
out
state
.
out_list
=
[]
if
state
.
finished
:
del
self
.
rid_to_state
[
rid
]
yield
out
break
event
.
clear
()
yield
out
else
:
if
obj
.
stream
:
raise
ValueError
(
"Do not support stream for batch mode."
)
...
...
@@ -298,24 +304,23 @@ class TokenizerManager:
async
def
handle_loop
(
self
):
while
True
:
recv_obj
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
recv_obj
:
BatchTokenIDOut
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
assert
isinstance
(
recv_obj
,
BatchStrOut
)
if
isinstance
(
recv_obj
,
BatchStrOut
):
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
if
state
is
None
:
continue
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
if
state
is
None
:
continue
recv_obj
.
meta_info
[
i
][
"id"
]
=
rid
out_dict
=
{
"text"
:
recv_obj
.
output_str
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished_reason
[
i
]
is
not
None
state
.
event
.
set
()
recv_obj
.
meta_info
[
i
][
"id"
]
=
rid
out_dict
=
{
"text"
:
recv_obj
.
output_str
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished
[
i
]
state
.
event
.
set
()
else
:
raise
ValueError
(
f
"Invalid object:
{
recv_obj
}
."
)
def
convert_logprob_style
(
self
,
ret
,
return_logprob
,
top_logprobs_num
,
return_text_in_logprobs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment