Unverified Commit 96be97bf authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

Minor PD style fix (#7215)

parent 88f9c347
from .conn import (
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
......
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
)
......@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
KVCache,
......@@ -248,6 +244,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
)
self.queue.append(
......@@ -636,15 +633,6 @@ class DecodeTransferQueue:
class SchedulerDisaggregationDecodeMixin:
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
batch, _ = self.prepare_dp_attn_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
return batch, result
@torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode."""
......@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
batch, _ = self.prepare_dp_attn_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
return batch, result
def get_next_disagg_decode_batch_to_run(
self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]:
......
from .conn import FakeKVReceiver, FakeKVSender
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional
import numpy as np
import numpy.typing as npt
......@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
......@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender):
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
logger.info("FakeKVSender poll success")
logger.debug("FakeKVSender poll success")
return KVPoll.Success
def init(
......@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: list[int],
aux_index: Optional[int] = None,
):
logger.info(
logger.debug(
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
)
pass
......@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: npt.NDArray[np.int32],
):
self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
def failure_exception(self):
raise Exception("Fake KVSender Exception")
......@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
logger.info("FakeKVReceiver poll success")
logger.debug("FakeKVReceiver poll success")
return KVPoll.Success
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
self.has_init = True
logger.info(
logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
)
......
from .conn import (
from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
......
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
from sglang.srt.disaggregation.nixl.conn import (
NixlKVBootstrapServer,
NixlKVManager,
NixlKVReceiver,
NixlKVSender,
)
......@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.NIXL:
elif transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.nixl import (
NixlKVBootstrapServer,
......@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE:
elif transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment