Unverified Commit bc938ea1 authored by Brayden Zhong's avatar Brayden Zhong Committed by GitHub
Browse files

Fix DP load for embedding (#9165)

parent eff4eb3f
......@@ -612,6 +612,8 @@ class EmbeddingReqInput:
if self.sampling_params is None:
self.sampling_params = [{}] * self.batch_size
elif isinstance(self.sampling_params, dict):
self.sampling_params = [self.sampling_params] * self.batch_size
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 0
......@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
token_type_ids: List[int]
# Dummy sampling params for compatibility
sampling_params: SamplingParams
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
# For dp balance
dp_balance_id: int = -1
......
......@@ -54,7 +54,7 @@ class SessionReqNode:
prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
ret += child._str_helper(prefix)
return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment