Unverified Commit bc938ea1 authored by Brayden Zhong's avatar Brayden Zhong Committed by GitHub
Browse files

Fix DP load for embedding (#9165)

parent eff4eb3f
...@@ -612,6 +612,8 @@ class EmbeddingReqInput: ...@@ -612,6 +612,8 @@ class EmbeddingReqInput:
if self.sampling_params is None: if self.sampling_params is None:
self.sampling_params = [{}] * self.batch_size self.sampling_params = [{}] * self.batch_size
elif isinstance(self.sampling_params, dict):
self.sampling_params = [self.sampling_params] * self.batch_size
for i in range(self.batch_size): for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 0 self.sampling_params[i]["max_new_tokens"] = 0
...@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput: ...@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
token_type_ids: List[int] token_type_ids: List[int]
# Dummy sampling params for compatibility # Dummy sampling params for compatibility
sampling_params: SamplingParams sampling_params: SamplingParams
# For data parallel rank routing
data_parallel_rank: Optional[int] = None
# For dp balance # For dp balance
dp_balance_id: int = -1 dp_balance_id: int = -1
......
...@@ -54,7 +54,7 @@ class SessionReqNode: ...@@ -54,7 +54,7 @@ class SessionReqNode:
prefix += " -- " + self.childs[0].req.rid prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix) ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]: for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + " \- " + child.req.rid prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
ret += child._str_helper(prefix) ret += child._str_helper(prefix)
return ret return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment