Unverified commit 96be97bf, authored by Byron Hsu and committed by GitHub
Browse files

Minor PD style fix (#7215)

parent 88f9c347
from .conn import ( from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer, BaseKVBootstrapServer,
BaseKVManager, BaseKVManager,
BaseKVReceiver, BaseKVReceiver,
......
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
)
...@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import ( ...@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce, poll_and_all_reduce,
prepare_abort, prepare_abort,
) )
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
FINISH_ABORT,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
KVCache, KVCache,
...@@ -248,6 +244,7 @@ class DecodePreallocQueue: ...@@ -248,6 +244,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager, mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room, bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
) )
self.queue.append( self.queue.append(
...@@ -636,15 +633,6 @@ class DecodeTransferQueue: ...@@ -636,15 +633,6 @@ class DecodeTransferQueue:
class SchedulerDisaggregationDecodeMixin: class SchedulerDisaggregationDecodeMixin:
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
    """Pass ``batch`` through ``prepare_dp_attn_batch`` and run it if non-empty.

    Args:
        batch: The batch handed to ``self.prepare_dp_attn_batch``; the first
            element of that call's return value replaces it.
        delay_process: When true, ``self.process_batch_result`` is not called
            here, so the caller receives the raw result to handle itself.

    Returns:
        A ``(batch, result)`` tuple. ``result`` is ``None`` when the prepared
        batch is falsy and nothing was run.
    """
    batch, _ = self.prepare_dp_attn_batch(batch)

    # Nothing to execute: hand back the (falsy) batch with no result.
    if not batch:
        return batch, None

    result = self.run_batch(batch)
    if not delay_process:
        self.process_batch_result(batch, result)
    return batch, result
@torch.no_grad() @torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler): def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode.""" """A normal scheduler loop for decode worker in disaggregation mode."""
...@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue self.last_batch_in_queue = last_batch_in_queue
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
    """Prepare ``batch`` via ``prepare_dp_attn_batch``, then run it when non-empty.

    Args:
        batch: Input to ``self.prepare_dp_attn_batch``; rebound to the first
            element of what that call returns.
        delay_process: If true, skip ``self.process_batch_result`` so the
            caller can process the returned result later.

    Returns:
        ``(batch, result)``, where ``result`` is the value of
        ``self.run_batch(batch)``, or ``None`` if the prepared batch is falsy.
    """
    prepared = self.prepare_dp_attn_batch(batch)
    batch, _ = prepared

    # A falsy prepared batch means there is nothing to run or process.
    if not batch:
        return batch, None

    result = self.run_batch(batch)
    if delay_process:
        return batch, result

    self.process_batch_result(batch, result)
    return batch, result
def get_next_disagg_decode_batch_to_run( def get_next_disagg_decode_batch_to_run(
self: Scheduler, self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]: ) -> Optional[Tuple[ScheduleBatch, bool]]:
......
from .conn import FakeKVReceiver, FakeKVSender from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
import logging import logging
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
BaseKVManager, BaseKVManager,
BaseKVReceiver, BaseKVReceiver,
BaseKVSender, BaseKVSender,
KVArgs,
KVPoll, KVPoll,
) )
...@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender): ...@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender):
return KVPoll.WaitingForInput return KVPoll.WaitingForInput
else: else:
# Assume transfer completed instantly # Assume transfer completed instantly
logger.info("FakeKVSender poll success") logger.debug("FakeKVSender poll success")
return KVPoll.Success return KVPoll.Success
def init( def init(
...@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender): ...@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: list[int], kv_indices: list[int],
aux_index: Optional[int] = None, aux_index: Optional[int] = None,
): ):
logger.info( logger.debug(
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}" f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
) )
pass pass
...@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender): ...@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: npt.NDArray[np.int32], kv_indices: npt.NDArray[np.int32],
): ):
self.has_sent = True self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}") logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
def failure_exception(self): def failure_exception(self):
raise Exception("Fake KVSender Exception") raise Exception("Fake KVSender Exception")
...@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver): ...@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
return KVPoll.WaitingForInput return KVPoll.WaitingForInput
else: else:
# Assume transfer completed instantly # Assume transfer completed instantly
logger.info("FakeKVReceiver poll success") logger.debug("FakeKVReceiver poll success")
return KVPoll.Success return KVPoll.Success
def init(self, kv_indices: list[int], aux_index: Optional[int] = None): def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
self.has_init = True self.has_init = True
logger.info( logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}" f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
) )
......
from .conn import ( from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer, MooncakeKVBootstrapServer,
MooncakeKVManager, MooncakeKVManager,
MooncakeKVReceiver, MooncakeKVReceiver,
......
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender from sglang.srt.disaggregation.nixl.conn import (
NixlKVBootstrapServer,
NixlKVManager,
NixlKVReceiver,
NixlKVSender,
)
...@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): ...@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
if transfer_backend == TransferBackend.NIXL: elif transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.nixl import ( from sglang.srt.disaggregation.nixl import (
NixlKVBootstrapServer, NixlKVBootstrapServer,
...@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): ...@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE: elif transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment