sglang · commit f1569876 (unverified)

feat: add direct routing strategy to DP worker (#6884)

Authored Jun 09, 2025 by ishandhanani; committed via GitHub Jun 09, 2025. Parent: 3465d7ae

Showing 12 changed files with 78 additions and 8 deletions (+78 −8):

- python/sglang/srt/disaggregation/common/conn.py (+7 −1)
- python/sglang/srt/disaggregation/decode.py (+1 −0)
- python/sglang/srt/disaggregation/fake/conn.py (+1 −0)
- python/sglang/srt/disaggregation/mooncake/conn.py (+7 −1)
- python/sglang/srt/disaggregation/nixl/conn.py (+2 −1)
- python/sglang/srt/entrypoints/EngineBase.py (+6 −0)
- python/sglang/srt/entrypoints/engine.py (+26 −0)
- python/sglang/srt/managers/data_parallel_controller.py (+13 −5)
- python/sglang/srt/managers/io_struct.py (+9 −0)
- python/sglang/srt/managers/schedule_batch.py (+4 −0)
- python/sglang/srt/managers/scheduler.py (+1 −0)
- python/sglang/srt/managers/tokenizer_manager.py (+1 −0)
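Taken together, the diffs below thread an optional `data_parallel_rank` field from the public engine API down to the data-parallel controller and the disaggregation KV receivers, letting a caller pin a request to a specific DP worker instead of relying on round-robin or bootstrap-room hashing. A minimal usage sketch, assuming a local `sglang.Engine` with two DP workers; the model path and prompt are illustrative, and `dp_size` / `enable_dp_attention` are assumed to be forwarded to the server args:

```python
import sglang as sgl

# Illustrative setup: two data-parallel workers with DP attention enabled.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # any local/HF model path
    dp_size=2,
    enable_dp_attention=True,
)

# Pin this request to DP rank 1. Per the engine.py diff, valid values are
# 0 <= rank < dp_size; out-of-range values raise ValueError, and omitting
# the argument falls back to the default dispatch (e.g. round-robin).
out = engine.generate(
    prompt="What is the capital of France?",
    sampling_params={"max_new_tokens": 32},
    data_parallel_rank=1,
)
print(out["text"])
```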
python/sglang/srt/disaggregation/common/conn.py

```diff
@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
+        self.data_parallel_rank = data_parallel_rank
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
-        self.target_dp_group = bootstrap_room % self.prefill_dp_size
+        if self.data_parallel_rank is not None:
+            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
+            self.target_dp_group = self.data_parallel_rank
+        else:
+            self.target_dp_group = bootstrap_room % self.prefill_dp_size

         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
```
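The receiver-side selection is small enough to restate on its own. A sketch paraphrasing the hunk above (a plain function rather than the `CommonKVReceiver` method; names follow the diff):

```python
from typing import Optional

def pick_target_dp_group(
    bootstrap_room: int,
    prefill_dp_size: int,
    data_parallel_rank: Optional[int] = None,
) -> int:
    """Choose the prefill DP group for a decode-side KV receiver.

    An explicit DP rank wins; otherwise keep the original behavior of
    hashing the bootstrap room across the prefill DP groups.
    """
    if data_parallel_rank is not None:
        return data_parallel_rank
    return bootstrap_room % prefill_dp_size

assert pick_target_dp_group(bootstrap_room=7, prefill_dp_size=4) == 3  # hashed
assert pick_target_dp_group(7, 4, data_parallel_rank=1) == 1           # pinned
```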
python/sglang/srt/disaggregation/decode.py

```diff
@@ -156,6 +156,7 @@ class DecodePreallocQueue:
                 mgr=self.kv_manager,
                 bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                 bootstrap_room=req.bootstrap_room,
+                data_parallel_rank=req.data_parallel_rank,
             )
             self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
```
python/sglang/srt/disaggregation/fake/conn.py

```diff
@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
         mgr: BaseKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.has_init = False
```
python/sglang/srt/disaggregation/mooncake/conn.py

```diff
@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         mgr: MooncakeKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.bootstrap_room = bootstrap_room
         self.bootstrap_addr = bootstrap_addr
@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
+        self.data_parallel_rank = data_parallel_rank
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
-        self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
+        if self.data_parallel_rank is not None:
+            logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
+            self.target_dp_group = self.data_parallel_rank
+        else:
+            self.target_dp_group = bootstrap_room % self.prefill_dp_size

         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
```
python/sglang/srt/disaggregation/nixl/conn.py

```diff
@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
-        super().__init__(mgr, bootstrap_addr, bootstrap_room)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
```
python/sglang/srt/entrypoints/EngineBase.py

```diff
@@ -23,6 +23,12 @@ class EngineBase(ABC):
         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
         lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
+        return_hidden_states: Optional[bool] = None,
+        stream: Optional[bool] = None,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """Generate outputs based on given inputs."""
         pass
```
python/sglang/srt/entrypoints/engine.py

```diff
@@ -167,11 +167,22 @@ class Engine(EngineBase):
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        if self.server_args.enable_dp_attention:
+            if data_parallel_rank is None:
+                logger.info("data_parallel_rank not provided, using default dispatch")
+            elif data_parallel_rank < 0:
+                raise ValueError("data_parallel_rank must be non-negative")
+            elif data_parallel_rank >= self.server_args.dp_size:
+                raise ValueError(
+                    f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
+                )
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -188,6 +199,7 @@ class Engine(EngineBase):
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
+            data_parallel_rank=data_parallel_rank,
         )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -237,11 +249,24 @@ class Engine(EngineBase):
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
+        data_parallel_rank: Optional[int] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        if self.server_args.enable_dp_attention:
+            if data_parallel_rank is None:
+                logger.info("data_parallel_rank not provided, using default dispatch")
+            elif data_parallel_rank < 0:
+                raise ValueError("data_parallel_rank must be non-negative")
+            elif data_parallel_rank >= self.server_args.dp_size:
+                raise ValueError(
+                    f"data_parallel_rank must be in range [0, {self.server_args.dp_size - 1}]"
+                )
+
+            logger.info(f"data_parallel_rank: {data_parallel_rank}")
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -257,6 +282,7 @@ class Engine(EngineBase):
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
             bootstrap_room=bootstrap_room,
+            data_parallel_rank=data_parallel_rank,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
```
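The async path applies the same validation before building `GenerateReqInput`. A streaming sketch, assuming the hypothetical `engine` from the earlier example and the usual sglang contract that `async_generate(..., stream=True)` yields dict chunks:

```python
import asyncio

async def stream_from_rank(engine, rank: int) -> None:
    # async_generate checks 0 <= rank < dp_size when DP attention is on,
    # then forwards data_parallel_rank via GenerateReqInput.
    generator = await engine.async_generate(
        prompt="Hello",
        sampling_params={"max_new_tokens": 16},
        stream=True,
        data_parallel_rank=rank,
    )
    async for chunk in generator:
        print(chunk["text"])  # each chunk is a dict carrying decoded text

# asyncio.run(stream_from_rank(engine, rank=0))
```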
python/sglang/srt/managers/data_parallel_controller.py

```diff
@@ -248,12 +248,20 @@ class DataParallelController:
     def round_robin_scheduler(self, req: Req):
         if self.server_args.disaggregation_mode == "null":
-            self.workers[self.round_robin_counter].send_pyobj(req)
-            self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                self.workers
-            )
+            if req.data_parallel_rank is not None:
+                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+                self.workers[req.data_parallel_rank].send_pyobj(req)
+            else:
+                self.workers[self.round_robin_counter].send_pyobj(req)
+                self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                    self.workers
+                )
         else:
-            self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
+            if req.data_parallel_rank is not None:
+                logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
+                self.workers[req.data_parallel_rank].send_pyobj(req)
+            else:
+                self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)

     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
```
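One subtlety in the controller change: the round-robin counter advances only when a request actually takes the round-robin path, so directly routed requests do not perturb the rotation seen by everyone else. A toy stand-in (not the actual `DataParallelController`) that demonstrates the interleaving:

```python
from typing import List, Optional

class ToyController:
    """Minimal stand-in for DataParallelController.round_robin_scheduler."""

    def __init__(self, num_workers: int) -> None:
        self.num_workers = num_workers
        self.round_robin_counter = 0
        self.assignments: List[int] = []  # worker index chosen per request

    def dispatch(self, data_parallel_rank: Optional[int] = None) -> None:
        if data_parallel_rank is not None:
            # Direct routing: honor the requested rank, leave the counter alone.
            self.assignments.append(data_parallel_rank)
        else:
            self.assignments.append(self.round_robin_counter)
            self.round_robin_counter = (self.round_robin_counter + 1) % self.num_workers

c = ToyController(num_workers=2)
c.dispatch()                      # worker 0 (round robin)
c.dispatch(data_parallel_rank=1)  # worker 1 (direct; counter untouched)
c.dispatch()                      # worker 1 (round robin resumes where it left off)
assert c.assignments == [0, 1, 1]
```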
python/sglang/srt/managers/io_struct.py

```diff
@@ -106,6 +106,9 @@ class GenerateReqInput:
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
     bootstrap_room: Optional[Union[List[int], int]] = None

+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
+
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -417,6 +420,9 @@ class GenerateReqInput:
             bootstrap_room=(
                 self.bootstrap_room[i] if self.bootstrap_room is not None else None
             ),
+            data_parallel_rank=(
+                self.data_parallel_rank if self.data_parallel_rank is not None else None
+            ),
         )
@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None

+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
+

 @dataclass
 class EmbeddingReqInput:
```
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -451,6 +451,7 @@ class Req:
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        data_parallel_rank: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -605,6 +606,9 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

+        # For data parallel rank routing
+        self.data_parallel_rank: Optional[int] = data_parallel_rank
+
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -949,6 +949,7 @@ class Scheduler(
                 bootstrap_host=recv_req.bootstrap_host,
                 bootstrap_port=recv_req.bootstrap_port,
                 bootstrap_room=recv_req.bootstrap_room,
+                data_parallel_rank=recv_req.data_parallel_rank,
             )
             req.tokenizer = self.tokenizer
```
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -570,6 +570,7 @@ class TokenizerManager:
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
                 return_hidden_states=obj.return_hidden_states,
+                data_parallel_rank=obj.data_parallel_rank,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
```
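End to end, the new field travels: `Engine.generate` / `Engine.async_generate` → `GenerateReqInput` → `TokenizerManager` → `TokenizedGenerateReqInput` → `Scheduler` → `Req`, and from there either to the `DataParallelController` for direct worker routing or, in disaggregated mode, through `DecodePreallocQueue` into the KV receivers (`CommonKVReceiver`, `MooncakeKVReceiver`, `NixlKVReceiver`, `FakeKVReceiver`), which use it to pick the target prefill DP group.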