"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ac863934870556505f6035127ed39466e57b6002"
Unverified Commit f202ed97 authored by Ying Sheng, committed by GitHub

[Refactor] Simplify io_struct and tokenizer_manager (#1549)

parent 100f5b8b
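
In short: `GenerateReqInput.post_init` now derives `is_single`, `batch_size`, and `parallel_sample_num` up front and promotes a single prompt with parallel sampling (`n > 1`) into a batch of one, which lets the redundant `is_single: bool = True` fields be deleted from the request dataclasses. On the tokenizer side, the tokenize-and-send half of `_handle_single_request` is factored out into a new `_send_single_request` helper that the batch path reuses.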
```diff
@@ -36,7 +36,7 @@ class GenerateReqInput:
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
```
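
A side note on this first hunk: the old annotation was inconsistent with its own default, because `Union[List[Dict], Dict]` does not admit `None`. A minimal standalone illustration (not part of the patch):

```python
from typing import Dict, List, Optional, Union

# Old: the default None is not a member of the declared Union, so a type
# checker such as mypy flags the assignment.
# sampling_params: Union[List[Dict], Dict] = None

# New: Optional[X] is shorthand for Union[X, None], which admits the default.
sampling_params: Optional[Union[List[Dict], Dict]] = None
```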
```diff
@@ -55,28 +55,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
-        if (
-            isinstance(self.sampling_params, dict)
-            and self.sampling_params.get("n", 1) != 1
-        ):
-            is_single = False
+
+        self.is_single = False
+        if self.text is not None:
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.text)
         else:
-            if self.text is not None:
-                is_single = isinstance(self.text, str)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
             else:
-                is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
+                self.batch_size = len(self.input_ids)
+
+        if self.sampling_params is None:
+            self.parallel_sample_num = 1
+        elif isinstance(self.sampling_params, dict):
+            self.parallel_sample_num = self.sampling_params.get("n", 1)
+        else:  # isinstance(self.sampling_params, list):
+            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
+            for sp in self.sampling_params:
+                # TODO cope with the case that the parallel_sample_num is different for different samples
+                assert self.parallel_sample_num == sp.get(
+                    "n", 1
+                ), "The parallel_sample_num should be the same for all samples in sample params."
+
+        if self.parallel_sample_num > 1:
+            if self.is_single:
+                self.is_single = False
+                if self.text is not None:
+                    self.text = [self.text]
+                if self.input_ids is not None:
+                    self.input_ids = [self.input_ids]
 
-        if is_single:
+        if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
             if self.rid is None:
```
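
To see what the reworked `post_init` does, here is a minimal sketch using a stripped-down stand-in class (text-only; the real `GenerateReqInput` also handles `input_ids` and the other fields). It shows how a single prompt with `n > 1` is promoted to a batch of one:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class MiniReq:
    # Stand-in for GenerateReqInput, keeping only the fields that the
    # normalization above touches.
    text: Optional[Union[List[str], str]] = None
    sampling_params: Optional[Union[List[Dict], Dict]] = None

    def post_init(self):
        # Derive is_single / batch_size from the shape of the input.
        self.is_single = isinstance(self.text, str)
        self.batch_size = 1 if self.is_single else len(self.text)
        # Derive parallel_sample_num from sampling_params ("n").
        if self.sampling_params is None:
            self.parallel_sample_num = 1
        elif isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:
            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
        # A single prompt with n > 1 is promoted to a batch of size 1.
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            self.text = [self.text]


req = MiniReq(text="Hello", sampling_params={"n": 2})
req.post_init()
assert req.is_single is False and req.text == ["Hello"]
```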
```diff
@@ -88,79 +107,54 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num_list = []
-            if isinstance(self.sampling_params, dict):
-                parallel_sample_num = self.sampling_params.get("n", 1)
-            elif isinstance(self.sampling_params, list):
-                for sp in self.sampling_params:
-                    parallel_sample_num = sp.get("n", 1)
-                    parallel_sample_num_list.append(parallel_sample_num)
-                parallel_sample_num = max(parallel_sample_num_list)
-                all_equal = all(
-                    element == parallel_sample_num
-                    for element in parallel_sample_num_list
-                )
-                if parallel_sample_num > 1 and (not all_equal):
-                    # TODO cope with the case that the parallel_sample_num is different for different samples
-                    raise ValueError(
-                        "The parallel_sample_num should be the same for all samples in sample params."
-                    )
+            if self.parallel_sample_num == 1:
+                num = self.batch_size
             else:
-                parallel_sample_num = 1
-            self.parallel_sample_num = parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # parallel sampling +1 represents the original prefill stage
-                num = parallel_sample_num + 1
-                if isinstance(self.text, list):
-                    # suppot batch operation
-                    self.batch_size = len(self.text)
-                    num = num * len(self.text)
-                elif isinstance(self.input_ids, list) and isinstance(
-                    self.input_ids[0], list
-                ):
-                    self.batch_size = len(self.input_ids)
-                    num = num * len(self.input_ids)
-                else:
-                    self.batch_size = 1
-            else:
-                # support select operation
-                num = len(self.text) if self.text is not None else len(self.input_ids)
-                self.batch_size = num
+                # FIXME support cascade inference
+                # first bs samples are used for caching the prefix for parallel sampling
+                num = self.batch_size + self.parallel_sample_num * self.batch_size
 
             if self.image_data is None:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                # multi-image with n > 1
+                # FIXME incorrect order for duplication
                 self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
             elif not isinstance(self.return_logprob, list):
                 self.return_logprob = [self.return_logprob] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.logprob_start_len is None:
                 self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = [0] * num
             elif not isinstance(self.top_logprobs_num, list):
                 self.top_logprobs_num = [self.top_logprobs_num] * num
+            else:
+                assert self.parallel_sample_num == 1
 
 
 @dataclass
```
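
The new `num` computation above is easy to check by hand: under parallel sampling, the padded per-request lists reserve one slot per prompt for the shared-prefix prefill plus `n` slots per prompt for the actual samples. With illustrative values:

```python
batch_size = 2           # two prompts in the batch
parallel_sample_num = 3  # sampling_params = {"n": 3}

# First batch_size slots cache the prefixes; the rest hold the samples.
num = batch_size + parallel_sample_num * batch_size
assert num == 8  # 2 prefill slots + 6 sample slots

# rid, sampling_params, return_logprob, ... are all padded to this length.
```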
```diff
@@ -199,8 +193,6 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -255,8 +247,6 @@ class RewardReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)
```
(The hunks above modify `io_struct`; the remaining hunks modify `tokenizer_manager`.)
```diff
@@ -159,58 +159,72 @@ class TokenizerManager:
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def _handle_single_request(
+    async def _send_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
-            not_use_index = index is None
-
-            rid = obj.rid if not_use_index else obj.rid[index]
-            input_text = obj.text if not_use_index else obj.text[index]
-            if hasattr(obj, "conv"):
-                # reward model
-                assert self.tokenizer is not None
-                conv = obj.conv if not_use_index else obj.conv[index]
-                input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
-                input_ids = self.tokenizer.encode(input_text)
-            elif obj.input_ids is None:
-                assert self.tokenizer is not None
-                input_ids = self.tokenizer.encode(input_text)
+            if index is None:
+                rid = obj.rid
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = obj.text if obj.text is not None else None
+                    input_ids = obj.input_ids
+
+                sampling_params = self._get_sampling_params(obj.sampling_params)
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data, obj
+                    )
+                    return_logprob = obj.return_logprob
+                    logprob_start_len = obj.logprob_start_len
+                    top_logprobs_num = obj.top_logprobs_num
             else:
-                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
+                rid = obj.rid[index]
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv[index]
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text[input_id_index]
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = (
+                        obj.text[input_id_index] if obj.text is not None else None
+                    )
+                    input_ids = obj.input_ids[input_id_index]
 
-            self._validate_input_length(input_ids)
-
-            sampling_params = self._get_sampling_params(
-                obj.sampling_params if not_use_index else obj.sampling_params[index]
-            )
-            if self.is_generation:
-                image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data if not_use_index else obj.image_data[index], obj
-                )
-                return_logprob = (
-                    obj.return_logprob if not_use_index else obj.return_logprob[index]
-                )
-                logprob_start_len = (
-                    obj.logprob_start_len
-                    if not_use_index
-                    else obj.logprob_start_len[index]
-                )
-                top_logprobs_num = (
-                    obj.top_logprobs_num
-                    if not_use_index
-                    else obj.top_logprobs_num[index]
-                )
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data[index], obj
+                    )
+                    return_logprob = obj.return_logprob[index]
+                    logprob_start_len = obj.logprob_start_len[index]
+                    top_logprobs_num = obj.top_logprobs_num[index]
+
+            self._validate_input_length(input_ids)
         else:  # A prefill request to cache the common prompt for parallel sampling
             assert self.is_generation
             if obj.text is not None:
                 if isinstance(obj.text, list):
-                    input_text = obj.text[index]
+                    input_text = obj.text[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_text = obj.text
```
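
Two different indices are now threaded through `_send_single_request`: `index` addresses the per-sample lists that `post_init` padded to length `num` (`rid`, `sampling_params`, the logprob options), while `input_id_index` addresses the original prompts, of which there are only `batch_size`. An illustrative sketch of the shapes, with an arbitrary slot layout (the exact slot formula the patch uses appears in a code comment further below):

```python
# Illustrative only: shapes implied by the hunks above, not library code.
batch_size, n = 2, 3
num = batch_size + n * batch_size        # 8 slots after post_init

texts = ["prompt A", "prompt B"]         # len == batch_size
rids = [f"rid-{k}" for k in range(num)]  # len == num

# All n samples of prompt i share input_id_index == i, but each sample
# owns a distinct `index` into the padded per-sample lists.
input_id_index, index = 1, 5
prompt, rid = texts[input_id_index], rids[index]
print(prompt, rid)  # -> prompt B rid-5
```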
```diff
@@ -224,7 +238,7 @@ class TokenizerManager:
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -235,7 +249,7 @@ class TokenizerManager:
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -263,7 +277,7 @@ class TokenizerManager:
                 top_logprobs_num,
                 obj.stream,
                 (
-                    obj.lora_path[index]
+                    obj.lora_path[input_id_index]
                     if isinstance(obj.lora_path, list)
                     else obj.lora_path
                 ),
```
```diff
@@ -283,12 +297,30 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
 
         self.send_to_scheduler.send_pyobj(tokenized_obj)
+        return rid, input_ids
+
+    async def _handle_single_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
+    ):
+        rid, input_ids = await self._send_single_request(
+            obj,
+            index,
+            input_id_index=input_id_index,
+            is_cache_for_prefill=is_cache_for_prefill,
+        )
 
         # Recv results
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
 
         if not is_cache_for_prefill:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
```
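
After this hunk, `_handle_single_request` is a thin wrapper: send the tokenized request, register a `ReqState` keyed by `rid`, then stream responses as they arrive. The underlying rendezvous is a standard asyncio pattern; here is a self-contained sketch with generic names (not the library's actual types):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class State:
    # Results delivered by a reader task, plus an event to wake the waiter.
    out_list: list = field(default_factory=list)
    event: asyncio.Event = field(default_factory=asyncio.Event)


rid_to_state: dict[str, State] = {}


async def handle_request(rid: str):
    state = rid_to_state[rid]
    await state.event.wait()  # parked until the reader signals
    return state.out_list.pop()


async def deliver(rid: str, result):
    # Stands in for the loop that receives results from the scheduler.
    state = rid_to_state[rid]
    state.out_list.append(result)
    state.event.set()


async def main():
    rid_to_state["r1"] = State()  # registration precedes waiting
    waiter = asyncio.create_task(handle_request("r1"))
    await asyncio.sleep(0)  # let the waiter park on the event
    await deliver("r1", "hello")
    print(await waiter)  # -> hello


asyncio.run(main())
```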
```diff
@@ -312,14 +344,16 @@ class TokenizerManager:
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
+                    obj,
+                    request,
+                    index=i,
+                    input_id_index=i,
+                    is_cache_for_prefill=True,
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
+            if input_id_result is not None:
                 obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
         else:
             parallel_sample_num = 1
```
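
Because `post_init` now keeps parallel-sampling requests in batch form even for a single prompt, the prefill loop no longer unwraps a singleton result: `obj.input_ids` stays a `List[List[int]]`. A small illustration of the shape difference, with hypothetical token ids:

```python
# Hypothetical token ids collected for a batch of one prompt.
input_id_result = [[101, 2023, 2003]]

# Old behavior: a singleton was unwrapped to a flat List[int].
old_input_ids = input_id_result[0]  # [101, 2023, 2003]

# New behavior: the List[List[int]] shape is preserved, matching the batch
# form that post_init now guarantees under parallel sampling.
new_input_ids = input_id_result     # [[101, 2023, 2003]]
```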
```diff
@@ -333,69 +367,10 @@ class TokenizerManager:
                 if parallel_sample_num != 1:
                     # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
-                rid = obj.rid[index]
-                if parallel_sample_num == 1:
-                    ## select operation
-                    if hasattr(obj, "conv"):
-                        # reward model
-                        conv = obj.conv[i]
-                        input_text = self.tokenizer.apply_chat_template(
-                            conv, tokenize=False
-                        )
-                        input_ids = self.tokenizer.encode(input_text)
-                    elif obj.input_ids is None:
-                        input_text = obj.text[i]
-                        input_ids = self.tokenizer.encode(input_text)
-                    else:
-                        input_text = None
-                        input_ids = obj.input_ids[i]
-                else:
-                    assert obj.input_ids is not None
-                    if batch_size == 1:
-                        input_text = None
-                        input_ids = obj.input_ids
-                    else:
-                        input_text = None
-                        input_ids = obj.input_ids[i]
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], obj
-                    )
-
-                    tokenized_obj = TokenizedGenerateReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        image_inputs,
-                        sampling_params,
-                        obj.return_logprob[index],
-                        obj.logprob_start_len[index],
-                        obj.top_logprobs_num[index],
-                        obj.stream,
-                        (
-                            obj.lora_path[index]
-                            if isinstance(obj.lora_path, list)
-                            else obj.lora_path
-                        ),
-                    )
-                elif isinstance(obj, EmbeddingReqInput):
-                    tokenized_obj = TokenizedEmbeddingReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        sampling_params,
-                    )
-                else:
-                    assert isinstance(obj, RewardReqInput)
-                    tokenized_obj = TokenizedRewardReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        sampling_params,
-                    )
-                self.send_to_scheduler.send_pyobj(tokenized_obj)
+
+                rid, _ = await self._send_single_request(
+                    obj, index, input_id_index=i, is_cache_for_prefill=False
+                )
 
                 event = asyncio.Event()
                 state = ReqState([], False, event)
```
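
The index adjustment kept in this hunk can be checked algebraically. Assuming the enclosing loop computes `index = i * parallel_sample_num + j` before the adjustment (that line sits outside the hunk shown here), `index += batch_size - 1 - i` yields exactly the formula quoted in the comment:

```python
batch_size, n = 2, 3  # illustrative values

for i in range(batch_size):
    for j in range(n):
        index = i * n + j            # assumed pre-adjustment value
        index += batch_size - 1 - i  # the adjustment in the hunk
        # The formula from the code comment above:
        assert index == j + i * (n - 1) + batch_size - 1
```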
```diff
@@ -418,7 +393,7 @@ class TokenizerManager:
         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
         output_list = [None] * len(tasks)
 
-        # Recv results
+        # Fetch results
         while tasks:
             done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
```
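
The batch path fans results back in by racing one `__anext__()` task per response generator and re-arming whichever finishes first. A self-contained sketch of that multiplexing pattern (generic, independent of the library's types):

```python
import asyncio


async def gen(name: str, n: int):
    # Stand-in for a per-request response generator.
    for k in range(n):
        await asyncio.sleep(0.01 * (k + 1))
        yield f"{name}-{k}"


async def main():
    generators = [gen("a", 2), gen("b", 3)]
    tasks = {asyncio.create_task(g.__anext__()): g for g in generators}
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for t in done:
            g = tasks.pop(t)
            try:
                print(t.result())
                # Re-arm the generator that just produced a value.
                tasks[asyncio.create_task(g.__anext__())] = g
            except StopAsyncIteration:
                pass  # this generator is exhausted


asyncio.run(main())
```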