sglang · Commits · f202ed97

[Refactor] Simplify io_struct and tokenizer_manager (#1549)

Unverified commit f202ed97, authored Oct 01, 2024 by Ying Sheng and committed via GitHub on Oct 01, 2024. Parent: 100f5b8b.

Showing 2 changed files with 132 additions and 167 deletions (+132 −167):

    python/sglang/srt/managers/io_struct.py             +49 −59
    python/sglang/srt/managers/tokenizer_manager.py     +83 −108
python/sglang/srt/managers/io_struct.py  (view file @ f202ed97)
@@ -36,7 +36,7 @@ class GenerateReqInput:
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
@@ -55,28 +55,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-        if (
-            isinstance(self.sampling_params, dict)
-            and self.sampling_params.get("n", 1) != 1
-        ):
-            is_single = False
-        else:
-            if self.text is not None:
-                is_single = isinstance(self.text, str)
-            else:
-                is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
-        if is_single:
+
+        self.is_single = False
+        if self.text is not None:
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.text)
+        else:
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.input_ids)
+
+        if self.sampling_params is None:
+            self.parallel_sample_num = 1
+        elif isinstance(self.sampling_params, dict):
+            self.parallel_sample_num = self.sampling_params.get("n", 1)
+        else:  # isinstance(self.sampling_params, list):
+            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
+            for sp in self.sampling_params:
+                # TODO cope with the case that the parallel_sample_num is different for different samples
+                assert self.parallel_sample_num == sp.get(
+                    "n", 1
+                ), "The parallel_sample_num should be the same for all samples in sample params."
+
+        if self.parallel_sample_num > 1:
+            if self.is_single:
+                self.is_single = False
+                if self.text is not None:
+                    self.text = [self.text]
+                if self.input_ids is not None:
+                    self.input_ids = [self.input_ids]
+
+        if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
             if self.rid is None:
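Aside: the added `post_init` logic above treats a single prompt with `n > 1` as a batch (`self.text = [self.text]`, `self.is_single = False`). A minimal, self-contained sketch of that promotion rule (illustrative helper, not sglang code):

```python
from typing import List, Optional, Union


def promote_to_batch(
    text: Optional[Union[str, List[str]]], parallel_sample_num: int
) -> Optional[Union[str, List[str]]]:
    # Mirrors the hunk above: a single prompt with n > 1 becomes a one-element batch.
    if parallel_sample_num > 1 and isinstance(text, str):
        return [text]
    return text


assert promote_to_batch("hello", parallel_sample_num=1) == "hello"
assert promote_to_batch("hello", parallel_sample_num=4) == ["hello"]
assert promote_to_batch(["a", "b"], parallel_sample_num=4) == ["a", "b"]
```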
@@ -88,79 +107,54 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num_list = []
-            if isinstance(self.sampling_params, dict):
-                parallel_sample_num = self.sampling_params.get("n", 1)
-            elif isinstance(self.sampling_params, list):
-                for sp in self.sampling_params:
-                    parallel_sample_num = sp.get("n", 1)
-                    parallel_sample_num_list.append(parallel_sample_num)
-                parallel_sample_num = max(parallel_sample_num_list)
-                all_equal = all(
-                    element == parallel_sample_num
-                    for element in parallel_sample_num_list
-                )
-                if parallel_sample_num > 1 and (not all_equal):
-                    # TODO cope with the case that the parallel_sample_num is different for different samples
-                    raise ValueError(
-                        "The parallel_sample_num should be the same for all samples in sample params."
-                    )
-            else:
-                parallel_sample_num = 1
-            self.parallel_sample_num = parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # parallel sampling +1 represents the original prefill stage
-                num = parallel_sample_num + 1
-                if isinstance(self.text, list):
-                    # suppot batch operation
-                    self.batch_size = len(self.text)
-                    num = num * len(self.text)
-                elif isinstance(self.input_ids, list) and isinstance(
-                    self.input_ids[0], list
-                ):
-                    self.batch_size = len(self.input_ids)
-                    num = num * len(self.input_ids)
-                else:
-                    self.batch_size = 1
-            else:
-                # support select operation
-                num = len(self.text) if self.text is not None else len(self.input_ids)
-                self.batch_size = num
+            if self.parallel_sample_num == 1:
+                num = self.batch_size
+            else:
+                # FIXME support cascade inference
+                # first bs samples are used for caching the prefix for parallel sampling
+                num = self.batch_size + self.parallel_sample_num * self.batch_size
 
             if self.image_data is None:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                # multi-image with n > 1
+                # FIXME incorrect order for duplication
                 self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
             elif not isinstance(self.return_logprob, list):
                 self.return_logprob = [self.return_logprob] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.logprob_start_len is None:
                 self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = [0] * num
             elif not isinstance(self.top_logprobs_num, list):
                 self.top_logprobs_num = [self.top_logprobs_num] * num
+            else:
+                assert self.parallel_sample_num == 1
 
 
 @dataclass
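Aside: in the batch branch above, the number of requests sent to the scheduler is `batch_size` when `n == 1`, and `batch_size + parallel_sample_num * batch_size` otherwise, where the first `batch_size` entries cache the shared prefix. A minimal sketch of that arithmetic (illustrative helper, not sglang code):

```python
def expected_request_count(batch_size: int, parallel_sample_num: int) -> int:
    # Mirrors the num computation in the hunk above.
    if parallel_sample_num == 1:
        return batch_size
    # The first batch_size entries cache the shared prefix for parallel sampling;
    # the remaining parallel_sample_num * batch_size entries are the samples.
    return batch_size + parallel_sample_num * batch_size


assert expected_request_count(batch_size=2, parallel_sample_num=1) == 2
assert expected_request_count(batch_size=2, parallel_sample_num=3) == 8
```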
@@ -199,8 +193,6 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -255,8 +247,6 @@ class RewardReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)
python/sglang/srt/managers/tokenizer_manager.py  (view file @ f202ed97)
@@ -159,58 +159,72 @@ class TokenizerManager:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def _handle_single_request(
+    async def _send_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:
             # The normal case with a single prompt
-            not_use_index = index is None
-
-            rid = obj.rid if not_use_index else obj.rid[index]
-            input_text = obj.text if not_use_index else obj.text[index]
-            if hasattr(obj, "conv"):
-                # reward model
-                assert self.tokenizer is not None
-                conv = obj.conv if not_use_index else obj.conv[index]
-                input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
-                input_ids = self.tokenizer.encode(input_text)
-            elif obj.input_ids is None:
-                assert self.tokenizer is not None
-                input_ids = self.tokenizer.encode(input_text)
-            else:
-                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
-
-            self._validate_input_length(input_ids)
-
-            sampling_params = self._get_sampling_params(
-                obj.sampling_params if not_use_index else obj.sampling_params[index]
-            )
-            if self.is_generation:
-                image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data if not_use_index else obj.image_data[index], obj
-                )
-                return_logprob = (
-                    obj.return_logprob if not_use_index else obj.return_logprob[index]
-                )
-                logprob_start_len = (
-                    obj.logprob_start_len
-                    if not_use_index
-                    else obj.logprob_start_len[index]
-                )
-                top_logprobs_num = (
-                    obj.top_logprobs_num
-                    if not_use_index
-                    else obj.top_logprobs_num[index]
-                )
+            if index is None:
+                rid = obj.rid
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = obj.text if obj.text is not None else None
+                    input_ids = obj.input_ids
+
+                sampling_params = self._get_sampling_params(obj.sampling_params)
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data, obj
+                    )
+                    return_logprob = obj.return_logprob
+                    logprob_start_len = obj.logprob_start_len
+                    top_logprobs_num = obj.top_logprobs_num
+            else:
+                rid = obj.rid[index]
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv[index]
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text[input_id_index]
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = (
+                        obj.text[input_id_index] if obj.text is not None else None
+                    )
+                    input_ids = obj.input_ids[input_id_index]
+
+                self._validate_input_length(input_ids)
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data[index], obj
+                    )
+                    return_logprob = obj.return_logprob[index]
+                    logprob_start_len = obj.logprob_start_len[index]
+                    top_logprobs_num = obj.top_logprobs_num[index]
         else:
             # A prefill request to cache the common prompt for parallel sampling
             assert self.is_generation
             if obj.text is not None:
                 if isinstance(obj.text, list):
-                    input_text = obj.text[index]
+                    input_text = obj.text[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_text = obj.text
@@ -224,7 +238,7 @@ class TokenizerManager:
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -235,7 +249,7 @@ class TokenizerManager:
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -263,7 +277,7 @@ class TokenizerManager:
                 top_logprobs_num,
                 obj.stream,
                 (
-                    obj.lora_path[index]
+                    obj.lora_path[input_id_index]
                     if isinstance(obj.lora_path, list)
                     else obj.lora_path
                 ),
@@ -283,12 +297,30 @@ class TokenizerManager:
             input_ids,
             sampling_params,
         )
         self.send_to_scheduler.send_pyobj(tokenized_obj)
 
+        return rid, input_ids
+
+    async def _handle_single_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
+    ):
+        rid, input_ids = await self._send_single_request(
+            obj,
+            index,
+            input_id_index=input_id_index,
+            is_cache_for_prefill=is_cache_for_prefill,
+        )
+
         # Recv results
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
 
         if not is_cache_for_prefill:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
@@ -312,14 +344,16 @@ class TokenizerManager:
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
+                    obj,
+                    request,
+                    index=i,
+                    input_id_index=i,
+                    is_cache_for_prefill=True,
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
+            if input_id_result is not None:
                 obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
         else:
             parallel_sample_num = 1
@@ -333,69 +367,10 @@ class TokenizerManager:
                 if parallel_sample_num != 1:
                     # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
-                rid = obj.rid[index]
-                if parallel_sample_num == 1:
-                    ## select operation
-                    if hasattr(obj, "conv"):
-                        # reward model
-                        conv = obj.conv[i]
-                        input_text = self.tokenizer.apply_chat_template(
-                            conv, tokenize=False
-                        )
-                        input_ids = self.tokenizer.encode(input_text)
-                    elif obj.input_ids is None:
-                        input_text = obj.text[i]
-                        input_ids = self.tokenizer.encode(input_text)
-                    else:
-                        input_text = None
-                        input_ids = obj.input_ids[i]
-                else:
-                    assert obj.input_ids is not None
-                    if batch_size == 1:
-                        input_text = None
-                        input_ids = obj.input_ids
-                    else:
-                        input_text = None
-                        input_ids = obj.input_ids[i]
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], obj
-                    )
-                    tokenized_obj = TokenizedGenerateReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        image_inputs,
-                        sampling_params,
-                        obj.return_logprob[index],
-                        obj.logprob_start_len[index],
-                        obj.top_logprobs_num[index],
-                        obj.stream,
-                        (
-                            obj.lora_path[index]
-                            if isinstance(obj.lora_path, list)
-                            else obj.lora_path
-                        ),
-                    )
-                elif isinstance(obj, EmbeddingReqInput):
-                    tokenized_obj = TokenizedEmbeddingReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        sampling_params,
-                    )
-                else:
-                    assert isinstance(obj, RewardReqInput)
-                    tokenized_obj = TokenizedRewardReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        sampling_params,
-                    )
-                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                rid, _ = await self._send_single_request(
+                    obj, index, input_id_index=i, is_cache_for_prefill=False
+                )
 
                 event = asyncio.Event()
                 state = ReqState([], False, event)
@@ -418,7 +393,7 @@ class TokenizerManager:
         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
         output_list = [None] * len(tasks)
 
-        # Recv results
+        # Fetch results
         while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
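Aside: the net effect of the tokenizer_manager changes is a send/wait split: `_send_single_request` tokenizes and sends the request to the scheduler and returns `(rid, input_ids)`, while the new `_handle_single_request` wrapper registers per-request wait state (`ReqState`, keyed by `rid`) and streams responses. A minimal asyncio sketch of that shape (class, method, and field names are illustrative stand-ins, not sglang's API):

```python
import asyncio
import uuid
from dataclasses import dataclass, field


@dataclass
class ReqState:
    # Field names are assumed; the diff only shows ReqState([], False, event).
    out_list: list = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


class MiniTokenizerManager:
    def __init__(self):
        self.rid_to_state = {}

    async def send_single_request(self, text: str) -> str:
        # Tokenize and hand the request off to the scheduler; only the send happens here.
        rid = uuid.uuid4().hex
        return rid

    async def handle_single_request(self, text: str):
        # Send first, then register wait state keyed by rid and wait for the reply.
        rid = await self.send_single_request(text)
        state = ReqState()
        self.rid_to_state[rid] = state
        await state.event.wait()
        return state.out_list[-1]


async def main():
    m = MiniTokenizerManager()
    task = asyncio.create_task(m.handle_single_request("hello"))
    await asyncio.sleep(0)  # let the handler send and register its state
    rid, state = next(iter(m.rid_to_state.items()))
    state.out_list.append({"rid": rid, "text": "world"})  # pretend the scheduler replied
    state.event.set()
    print(await task)


asyncio.run(main())
```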