Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f70f7258
Unverified
Commit
f70f7258
authored
Jun 08, 2024
by
Qubitium
Committed by
GitHub
Jun 07, 2024
Browse files
Fix rid state map leak + Refractor .finished (#505)
Co-authored-by:
ZX
<
zx@lbx.dev
>
parent
c0ae70c8
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
134 additions
and
112 deletions
+134
-112
python/sglang/srt/managers/controller/dp_worker.py
python/sglang/srt/managers/controller/dp_worker.py
+4
-1
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+52
-31
python/sglang/srt/managers/controller/manager_single.py
python/sglang/srt/managers/controller/manager_single.py
+1
-1
python/sglang/srt/managers/controller/tp_worker.py
python/sglang/srt/managers/controller/tp_worker.py
+8
-13
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+40
-41
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+5
-6
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+24
-19
No files found.
python/sglang/srt/managers/controller/dp_worker.py
View file @
f70f7258
...
@@ -10,6 +10,7 @@ import zmq
...
@@ -10,6 +10,7 @@ import zmq
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.managers.controller.tp_worker
import
ModelTpClient
from
sglang.srt.managers.controller.tp_worker
import
ModelTpClient
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.utils
import
get_exception_traceback
from
sglang.utils
import
get_exception_traceback
...
@@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread):
...
@@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread):
requests
=
[]
requests
=
[]
while
not
self
.
request_queue
.
empty
():
while
not
self
.
request_queue
.
empty
():
requests
.
append
(
self
.
request_queue
.
get
())
requests
.
append
(
self
.
request_queue
.
get
())
out_pyobjs
:
List
[
BatchTokenIDOut
]
=
[]
try
:
try
:
out_pyobjs
=
await
self
.
step
(
requests
)
out_pyobjs
=
await
self
.
step
(
requests
)
except
Exception
:
except
Exception
:
...
@@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread):
...
@@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread):
# async sleep for receiving the subsequent request and avoiding cache miss
# async sleep for receiving the subsequent request and avoiding cache miss
if
len
(
out_pyobjs
)
!=
0
:
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
has_finished
=
any
([
obj
.
finished
_reason
is
not
None
for
obj
in
out_pyobjs
])
if
has_finished
:
if
has_finished
:
await
asyncio
.
sleep
(
self
.
request_dependency_delay
)
await
asyncio
.
sleep
(
self
.
request_dependency_delay
)
await
asyncio
.
sleep
(
global_config
.
wait_for_new_request_delay
)
await
asyncio
.
sleep
(
global_config
.
wait_for_new_request_delay
)
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
f70f7258
...
@@ -15,25 +15,47 @@ class ForwardMode(IntEnum):
...
@@ -15,25 +15,47 @@ class ForwardMode(IntEnum):
EXTEND
=
auto
()
EXTEND
=
auto
()
DECODE
=
auto
()
DECODE
=
auto
()
class
BaseFinishReason
:
def
__init__
(
self
,
is_error
:
bool
=
False
):
self
.
is_error
=
is_error
class
FinishReason
(
IntEnum
):
def
__str__
(
self
):
EOS_TOKEN
=
auto
()
raise
NotImplementedError
(
"Subclasses must implement this method"
)
LENGTH
=
auto
()
STOP_STR
=
auto
()
ABORT
=
auto
()
class
FINISH_MATCHED_TOKEN
(
BaseFinishReason
):
def
__init__
(
self
,
matched
:
int
|
List
[
int
]):
@
staticmethod
super
().
__init__
()
def
to_str
(
reason
):
self
.
matched
=
matched
if
reason
==
FinishReason
.
EOS_TOKEN
:
return
None
def
__str__
(
self
)
->
str
:
elif
reason
==
FinishReason
.
LENGTH
:
return
f
"FINISH_MATCHED_TOKEN:
{
self
.
matched
}
"
return
"length"
elif
reason
==
FinishReason
.
STOP_STR
:
return
"stop"
class
FINISH_LENGTH
(
BaseFinishReason
):
elif
reason
==
FinishReason
.
ABORT
:
def
__init__
(
self
,
length
:
int
):
return
"abort"
super
().
__init__
()
else
:
self
.
length
=
length
return
None
def
__str__
(
self
)
->
str
:
return
f
"FINISH_LENGTH:
{
self
.
length
}
"
class
FINISH_MATCHED_STR
(
BaseFinishReason
):
def
__init__
(
self
,
matched
:
str
):
super
().
__init__
()
self
.
matched
=
matched
def
__str__
(
self
)
->
str
:
return
f
"FINISH_MATCHED_STR:
{
self
.
matched
}
"
class
FINISH_ABORT
(
BaseFinishReason
):
def
__init__
(
self
):
super
().
__init__
(
is_error
=
True
)
def
__str__
(
self
)
->
str
:
return
"FINISH_ABORT"
class
Req
:
class
Req
:
...
@@ -61,11 +83,10 @@ class Req:
...
@@ -61,11 +83,10 @@ class Req:
self
.
sampling_params
=
None
self
.
sampling_params
=
None
self
.
stream
=
False
self
.
stream
=
False
# Check finish
self
.
tokenizer
=
None
self
.
tokenizer
=
None
self
.
finished
=
False
self
.
finish_reason
=
None
# Check finish
self
.
hit_stop_str
=
None
self
.
finished_reason
=
None
# Prefix info
# Prefix info
self
.
extend_input_len
=
0
self
.
extend_input_len
=
0
...
@@ -90,6 +111,10 @@ class Req:
...
@@ -90,6 +111,10 @@ class Req:
self
.
regex_fsm_state
=
0
self
.
regex_fsm_state
=
0
self
.
jump_forward_map
=
None
self
.
jump_forward_map
=
None
# whether request reached finished condition
def
finished
(
self
)
->
bool
:
return
self
.
finished_reason
is
not
None
def
partial_decode
(
self
,
ids
):
def
partial_decode
(
self
,
ids
):
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
ids
[
0
])
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
ids
[
0
])
first_token
=
(
first_token
=
(
...
@@ -101,23 +126,21 @@ class Req:
...
@@ -101,23 +126,21 @@ class Req:
return
self
.
sampling_params
.
max_new_tokens
return
self
.
sampling_params
.
max_new_tokens
def
check_finished
(
self
):
def
check_finished
(
self
):
if
self
.
finished
:
if
self
.
finished
()
:
return
return
if
(
if
(
len
(
self
.
prev_output_ids
)
+
len
(
self
.
output_ids
)
len
(
self
.
prev_output_ids
)
+
len
(
self
.
output_ids
)
>=
self
.
sampling_params
.
max_new_tokens
>=
self
.
sampling_params
.
max_new_tokens
):
):
self
.
finished
=
True
self
.
finished_reason
=
FINISH_LENGTH
(
len
(
self
.
prev_output_ids
)
+
len
(
self
.
output_ids
))
self
.
finish_reason
=
FinishReason
.
LENGTH
return
return
if
(
if
(
self
.
output_ids
[
-
1
]
==
self
.
tokenizer
.
eos_token_id
self
.
output_ids
[
-
1
]
==
self
.
tokenizer
.
eos_token_id
and
self
.
sampling_params
.
ignore_eos
==
False
and
not
self
.
sampling_params
.
ignore_eos
):
):
self
.
finished
=
True
self
.
finished_reason
=
FINISH_MATCHED_TOKEN
(
matched
=
self
.
tokenizer
.
eos_token_id
)
self
.
finish_reason
=
FinishReason
.
EOS_TOKEN
return
return
if
len
(
self
.
sampling_params
.
stop_strs
)
>
0
:
if
len
(
self
.
sampling_params
.
stop_strs
)
>
0
:
...
@@ -128,9 +151,7 @@ class Req:
...
@@ -128,9 +151,7 @@ class Req:
for
stop_str
in
self
.
sampling_params
.
stop_strs
:
for
stop_str
in
self
.
sampling_params
.
stop_strs
:
# FIXME: (minor) try incremental match in prev_output_str
# FIXME: (minor) try incremental match in prev_output_str
if
stop_str
in
tail_str
or
stop_str
in
self
.
prev_output_str
:
if
stop_str
in
tail_str
or
stop_str
in
self
.
prev_output_str
:
self
.
finished
=
True
self
.
finished_reason
=
FINISH_MATCHED_STR
(
matched
=
stop_str
)
self
.
finish_reason
=
FinishReason
.
STOP_STR
self
.
hit_stop_str
=
stop_str
return
return
def
jump_forward_and_retokenize
(
self
,
jump_forward_str
,
next_state
):
def
jump_forward_and_retokenize
(
self
,
jump_forward_str
,
next_state
):
...
...
python/sglang/srt/managers/controller/manager_single.py
View file @
f70f7258
...
@@ -45,7 +45,7 @@ class ControllerSingle:
...
@@ -45,7 +45,7 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss
# async sleep for receiving the subsequent request and avoiding cache miss
slept
=
False
slept
=
False
if
len
(
out_pyobjs
)
!=
0
:
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
has_finished
=
any
([
obj
.
finished
_reason
is
not
None
for
obj
in
out_pyobjs
])
if
has_finished
:
if
has_finished
:
if
self
.
request_dependency_delay
>
0
:
if
self
.
request_dependency_delay
>
0
:
slept
=
True
slept
=
True
...
...
python/sglang/srt/managers/controller/tp_worker.py
View file @
f70f7258
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq
,
FlushCacheReq
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
)
)
from
sglang.srt.managers.controller.infer_batch
import
Ba
tch
,
FinishReason
,
ForwardMode
,
Req
from
sglang.srt.managers.controller.infer_batch
import
Ba
se
FinishReason
,
Batch
,
FINISH_ABORT
,
ForwardMode
,
Req
from
sglang.srt.managers.controller.model_runner
import
ModelRunner
from
sglang.srt.managers.controller.model_runner
import
ModelRunner
from
sglang.srt.managers.controller.radix_cache
import
RadixCache
from
sglang.srt.managers.controller.radix_cache
import
RadixCache
from
sglang.srt.managers.controller.schedule_heuristic
import
ScheduleHeuristic
from
sglang.srt.managers.controller.schedule_heuristic
import
ScheduleHeuristic
...
@@ -595,20 +595,19 @@ class ModelTpServer:
...
@@ -595,20 +595,19 @@ class ModelTpServer:
output_rids
=
[]
output_rids
=
[]
prev_output_strs
=
[]
prev_output_strs
=
[]
output_tokens
=
[]
output_tokens
=
[]
output_hit_stop_str
=
[]
output_skip_special_tokens
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_meta_info
=
[]
output_meta_info
=
[]
output_finished
=
[]
output_finished
_reason
:
List
[
BaseFinishReason
]
=
[]
finished_indices
=
[]
finished_indices
=
[]
unfinished_indices
=
[]
unfinished_indices
=
[]
for
i
,
req
in
enumerate
(
batch
.
reqs
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
if
req
.
finished
:
if
req
.
finished
()
:
finished_indices
.
append
(
i
)
finished_indices
.
append
(
i
)
else
:
else
:
unfinished_indices
.
append
(
i
)
unfinished_indices
.
append
(
i
)
if
req
.
finished
or
(
if
req
.
finished
()
or
(
(
(
req
.
stream
req
.
stream
and
(
and
(
...
@@ -620,7 +619,6 @@ class ModelTpServer:
...
@@ -620,7 +619,6 @@ class ModelTpServer:
output_rids
.
append
(
req
.
rid
)
output_rids
.
append
(
req
.
rid
)
prev_output_strs
.
append
(
req
.
prev_output_str
)
prev_output_strs
.
append
(
req
.
prev_output_str
)
output_tokens
.
append
(
req
.
output_ids
)
output_tokens
.
append
(
req
.
output_ids
)
output_hit_stop_str
.
append
(
req
.
hit_stop_str
)
output_skip_special_tokens
.
append
(
output_skip_special_tokens
.
append
(
req
.
sampling_params
.
skip_special_tokens
req
.
sampling_params
.
skip_special_tokens
)
)
...
@@ -632,8 +630,7 @@ class ModelTpServer:
...
@@ -632,8 +630,7 @@ class ModelTpServer:
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"completion_tokens"
:
len
(
req
.
prev_output_ids
)
+
len
(
req
.
output_ids
),
"completion_tokens"
:
len
(
req
.
prev_output_ids
)
+
len
(
req
.
output_ids
),
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"completion_tokens_wo_jump_forward"
:
req
.
completion_tokens_wo_jump_forward
,
"finish_reason"
:
FinishReason
.
to_str
(
req
.
finish_reason
),
"finish_reason"
:
str
(
req
.
finished_reason
),
"hit_stop_str"
:
req
.
hit_stop_str
,
}
}
if
req
.
return_logprob
:
if
req
.
return_logprob
:
(
(
...
@@ -650,7 +647,7 @@ class ModelTpServer:
...
@@ -650,7 +647,7 @@ class ModelTpServer:
req
.
normalized_prompt_logprob
,
req
.
normalized_prompt_logprob
,
)
)
output_meta_info
.
append
(
meta_info
)
output_meta_info
.
append
(
meta_info
)
output_finished
.
append
(
req
.
finished
)
output_finished
_reason
.
append
(
req
.
finished
_reason
)
# Send to detokenizer
# Send to detokenizer
if
output_rids
:
if
output_rids
:
...
@@ -659,11 +656,10 @@ class ModelTpServer:
...
@@ -659,11 +656,10 @@ class ModelTpServer:
output_rids
,
output_rids
,
prev_output_strs
,
prev_output_strs
,
output_tokens
,
output_tokens
,
output_hit_stop_str
,
output_skip_special_tokens
,
output_skip_special_tokens
,
output_spaces_between_special_tokens
,
output_spaces_between_special_tokens
,
output_meta_info
,
output_meta_info
,
output_finished
,
output_finished
_reason
,
)
)
)
)
...
@@ -720,8 +716,7 @@ class ModelTpServer:
...
@@ -720,8 +716,7 @@ class ModelTpServer:
if
self
.
running_batch
:
if
self
.
running_batch
:
for
req
in
self
.
running_batch
.
reqs
:
for
req
in
self
.
running_batch
.
reqs
:
if
req
.
rid
==
recv_req
.
rid
:
if
req
.
rid
==
recv_req
.
rid
:
req
.
finished
=
True
req
.
finished_reason
=
FINISH_ABORT
()
req
.
finish_reason
=
FinishReason
.
ABORT
break
break
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
f70f7258
...
@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
...
@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
from
sglang.srt.managers.io_struct
import
BatchStrOut
,
BatchTokenIDOut
from
sglang.srt.managers.io_struct
import
BatchStrOut
,
BatchTokenIDOut
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.utils
import
get_exception_traceback
,
graceful_registry
from
sglang.utils
import
get_exception_traceback
,
graceful_registry
from
sglang.srt.managers.controller.infer_batch
import
FINISH_MATCHED_STR
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
@@ -34,49 +35,47 @@ class DetokenizerManager:
...
@@ -34,49 +35,47 @@ class DetokenizerManager:
async
def
handle_loop
(
self
):
async
def
handle_loop
(
self
):
while
True
:
while
True
:
recv_obj
=
await
self
.
recv_from_router
.
recv_pyobj
()
recv_obj
:
BatchTokenIDOut
=
await
self
.
recv_from_router
.
recv_pyobj
()
assert
isinstance
(
recv_obj
,
BatchTokenIDOut
)
if
isinstance
(
recv_obj
,
BatchTokenIDOut
):
output_tokens
=
recv_obj
.
output_tokens
output_tokens
=
recv_obj
.
output_tokens
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs
=
self
.
tokenizer
.
batch_decode
(
output_strs
=
self
.
tokenizer
.
batch_decode
(
output_tokens
,
output_tokens
,
skip_special_tokens
=
recv_obj
.
skip_special_tokens
[
0
],
skip_special_tokens
=
recv_obj
.
skip_special_tokens
[
0
],
spaces_between_special_tokens
=
recv_obj
.
spaces_between_special_tokens
[
spaces_between_special_tokens
=
recv_obj
.
spaces_between_special_tokens
[
0
0
],
],
)
)
# Trim stop str
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
# TODO(lmzheng): handle the case where multiple stop strs are hit
for
i
in
range
(
len
(
output_strs
)):
for
i
in
range
(
len
(
output_strs
)):
if
len
(
output_tokens
[
i
])
>
0
:
if
len
(
output_tokens
[
i
])
>
0
:
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
first_token
=
self
.
tokenizer
.
convert_ids_to_tokens
(
int
(
output_tokens
[
i
][
0
])
int
(
output_tokens
[
i
][
0
])
)
if
not
isinstance
(
first_token
,
str
):
first_token
=
first_token
.
decode
(
"utf-8"
,
errors
=
"ignore"
)
if
first_token
.
startswith
(
"▁"
):
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
recv_obj
.
prev_output_strs
[
i
]
+
output_strs
[
i
]
if
recv_obj
.
hit_stop_str
[
i
]
is
not
None
:
pos
=
output_strs
[
i
].
find
(
recv_obj
.
hit_stop_str
[
i
])
if
pos
!=
-
1
:
output_strs
[
i
]
=
output_strs
[
i
][:
pos
]
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
recv_obj
.
rids
,
output_strs
,
recv_obj
.
meta_info
,
recv_obj
.
finished
,
)
)
if
not
isinstance
(
first_token
,
str
):
first_token
=
first_token
.
decode
(
"utf-8"
,
errors
=
"ignore"
)
if
first_token
.
startswith
(
"▁"
):
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
recv_obj
.
prev_output_strs
[
i
]
+
output_strs
[
i
]
if
isinstance
(
recv_obj
.
finished_reason
[
i
],
FINISH_MATCHED_STR
):
pos
=
output_strs
[
i
].
find
(
recv_obj
.
finished_reason
[
i
].
matched
)
if
pos
!=
-
1
:
output_strs
[
i
]
=
output_strs
[
i
][:
pos
]
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
rids
=
recv_obj
.
rids
,
output_str
=
output_strs
,
meta_info
=
recv_obj
.
meta_info
,
finished_reason
=
recv_obj
.
finished_reason
,
)
)
else
:
)
raise
ValueError
(
f
"Invalid object:
{
recv_obj
}
"
)
def
start_detokenizer_process
(
def
start_detokenizer_process
(
...
...
python/sglang/srt/managers/io_struct.py
View file @
f70f7258
...
@@ -3,6 +3,7 @@ from dataclasses import dataclass
...
@@ -3,6 +3,7 @@ from dataclasses import dataclass
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.managers.controller.infer_batch
import
BaseFinishReason
@
dataclass
@
dataclass
...
@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput:
...
@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput:
@
dataclass
@
dataclass
class
BatchTokenIDOut
:
class
BatchTokenIDOut
:
rids
:
List
[
str
]
rids
:
List
[
str
]
prev_output_strs
:
List
[
str
]
prev_output_strs
:
List
[
str
]
output_tokens
:
List
[
List
[
int
]]
output_tokens
:
List
[
List
[
int
]]
hit_stop_str
:
List
[
Optional
[
str
]]
skip_special_tokens
:
List
[
bool
]
skip_special_tokens
:
List
[
bool
]
spaces_between_special_tokens
:
List
[
bool
]
spaces_between_special_tokens
:
List
[
bool
]
meta_info
:
List
[
Dict
]
meta_info
:
List
[
Dict
]
finished
:
List
[
bool
]
finished_reason
:
List
[
BaseFinishReason
]
@
dataclass
@
dataclass
class
BatchStrOut
:
class
BatchStrOut
:
rids
:
List
[
str
]
rids
:
List
[
str
]
output_str
:
List
[
str
]
output_str
:
List
[
str
]
meta_info
:
List
[
Dict
]
meta_info
:
List
[
Dict
]
finished
:
List
[
bool
]
finished
_reason
:
List
[
BaseFinishReason
]
@
dataclass
@
dataclass
...
@@ -134,4 +133,4 @@ class AbortReq:
...
@@ -134,4 +133,4 @@ class AbortReq:
@
dataclass
@
dataclass
class
DetokenizeReqInput
:
class
DetokenizeReqInput
:
input_ids
:
List
[
int
]
input_ids
:
List
[
int
]
\ No newline at end of file
python/sglang/srt/managers/tokenizer_manager.py
View file @
f70f7258
...
@@ -4,7 +4,7 @@ import dataclasses
...
@@ -4,7 +4,7 @@ import dataclasses
import
logging
import
logging
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
os
import
os
from
typing
import
List
from
typing
import
List
,
Dict
import
numpy
as
np
import
numpy
as
np
import
transformers
import
transformers
...
@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput
,
GenerateReqInput
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
)
)
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
from
sglang.srt.mm_utils
import
expand2square
,
process_anyres_image
from
sglang.srt.mm_utils
import
expand2square
,
process_anyres_image
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
@@ -89,7 +90,7 @@ class TokenizerManager:
...
@@ -89,7 +90,7 @@ class TokenizerManager:
)
)
self
.
to_create_loop
=
True
self
.
to_create_loop
=
True
self
.
rid_to_state
=
{}
#
Dict[str
->
ReqState]
self
.
rid_to_state
:
Dict
[
str
,
ReqState
]
=
{}
async
def
get_pixel_values
(
self
,
image_data
):
async
def
get_pixel_values
(
self
,
image_data
):
aspect_ratio
=
getattr
(
self
.
hf_config
,
"image_aspect_ratio"
,
None
)
aspect_ratio
=
getattr
(
self
.
hf_config
,
"image_aspect_ratio"
,
None
)
...
@@ -183,12 +184,17 @@ class TokenizerManager:
...
@@ -183,12 +184,17 @@ class TokenizerManager:
if
self
.
server_args
.
log_requests
and
state
.
finished
:
if
self
.
server_args
.
log_requests
and
state
.
finished
:
logger
.
info
(
f
"in=
{
obj
.
text
}
, out=
{
out
}
"
)
logger
.
info
(
f
"in=
{
obj
.
text
}
, out=
{
out
}
"
)
yield
out
state
.
out_list
=
[]
state
.
out_list
=
[]
if
state
.
finished
:
if
state
.
finished
:
del
self
.
rid_to_state
[
rid
]
del
self
.
rid_to_state
[
rid
]
yield
out
break
break
event
.
clear
()
event
.
clear
()
yield
out
else
:
else
:
if
obj
.
stream
:
if
obj
.
stream
:
raise
ValueError
(
"Do not support stream for batch mode."
)
raise
ValueError
(
"Do not support stream for batch mode."
)
...
@@ -298,24 +304,23 @@ class TokenizerManager:
...
@@ -298,24 +304,23 @@ class TokenizerManager:
async
def
handle_loop
(
self
):
async
def
handle_loop
(
self
):
while
True
:
while
True
:
recv_obj
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
recv_obj
:
BatchTokenIDOut
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
assert
isinstance
(
recv_obj
,
BatchStrOut
)
if
isinstance
(
recv_obj
,
BatchStrOut
):
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
for
i
,
rid
in
enumerate
(
recv_obj
.
rids
):
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
state
=
self
.
rid_to_state
.
get
(
rid
,
None
)
if
state
is
None
:
if
state
is
None
:
continue
continue
recv_obj
.
meta_info
[
i
][
"id"
]
=
rid
out_dict
=
{
"text"
:
recv_obj
.
output_str
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished_reason
[
i
]
is
not
None
state
.
event
.
set
()
recv_obj
.
meta_info
[
i
][
"id"
]
=
rid
out_dict
=
{
"text"
:
recv_obj
.
output_str
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
}
state
.
out_list
.
append
(
out_dict
)
state
.
finished
=
recv_obj
.
finished
[
i
]
state
.
event
.
set
()
else
:
raise
ValueError
(
f
"Invalid object:
{
recv_obj
}
."
)
def
convert_logprob_style
(
def
convert_logprob_style
(
self
,
ret
,
return_logprob
,
top_logprobs_num
,
return_text_in_logprobs
self
,
ret
,
return_logprob
,
top_logprobs_num
,
return_text_in_logprobs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment