Unverified Commit f0f84975 authored by ishandhanani, committed by GitHub

feat: add dp-rank to KV events (#6852)

parent 3f1e4339
@@ -43,6 +43,7 @@ class EventBatch(
 ):
     ts: float
     events: list[Any]
+    attn_dp_rank: Optional[int] = None


 class KVCacheEvent(
@@ -76,7 +77,21 @@ class KVEventBatch(EventBatch):
 class EventPublisher(ABC):
-    """Lightweight publisher for EventBatch batches."""
+    """
+    Lightweight publisher for EventBatch batches with
+    support for DP attention.
+
+    In DP attention, each rank has its own Scheduler and
+    KV cache instance. To avoid duplicate events and ensure
+    proper event attribution, in our implementation:
+    - each DP rank has its own EventPublisher
+    - publishers annotate events with their DP rank
+    - consumers can thus distinguish events from different DP ranks
+    """
+
+    def __init__(self, attn_dp_rank: int = 0):
+        self._attn_dp_rank = attn_dp_rank

     @abstractmethod
     def publish(self, events: EventBatch) -> None:
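For readers skimming the diff, here is a minimal sketch of how a subclass plugs into the new base-class constructor. The LoggingEventPublisher name and its print-based behavior are hypothetical (not part of this commit), and the sketch assumes publish() is the only abstract method:

class LoggingEventPublisher(EventPublisher):
    """Illustrative publisher that prints batches instead of sending them."""

    def publish(self, events: EventBatch) -> None:
        # Stamp the batch with this publisher's DP rank if the producer
        # did not tag it already (the same rule ZmqEventPublisher applies below).
        if events.attn_dp_rank is None:
            events.attn_dp_rank = self._attn_dp_rank
        print(f"[dp_rank={events.attn_dp_rank}] {len(events.events)} event(s)")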
@@ -130,6 +145,7 @@ class ZmqEventPublisher(EventPublisher):
     def __init__(
         self,
+        attn_dp_rank: int,
         endpoint: str = "tcp://*:5557",
         replay_endpoint: Optional[str] = None,
         buffer_steps: int = 10_000,
@@ -138,6 +154,7 @@ class ZmqEventPublisher(EventPublisher):
         topic: str = "",
     ) -> None:
         # Storage
+        super().__init__(attn_dp_rank)
         self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
         self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
@@ -145,8 +162,11 @@ class ZmqEventPublisher(EventPublisher):
         self._ctx = zmq.Context.instance()
         self._pub: Optional[zmq.Socket] = None
         self._replay: Optional[zmq.Socket] = None
-        self._endpoint = endpoint
-        self._replay_endpoint = replay_endpoint
+        self._dp_rank = attn_dp_rank
+        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
+        self._replay_endpoint = self.offset_endpoint_port(
+            replay_endpoint, self._dp_rank
+        )
         self._hwm = hwm
         self._socket_setup()
@@ -168,6 +188,8 @@ class ZmqEventPublisher(EventPublisher):
     def publish(self, events: EventBatch) -> None:
         if not self._running:
             raise RuntimeError("Publisher is closed")
+        if events.attn_dp_rank is None:
+            events.attn_dp_rank = self._dp_rank
         self._event_queue.put(events)

     def shutdown(self) -> None:
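One detail worth noting in publish(): the rank is only stamped when attn_dp_rank is still None, so a batch that was pre-tagged upstream keeps its original rank. A hedged illustration, assuming KVEventBatch supports keyword construction and publisher is an already-constructed ZmqEventPublisher:

import time

batch = KVEventBatch(ts=time.time(), events=[])  # attn_dp_rank defaults to None
publisher.publish(batch)       # stamped with this publisher's DP rank

pre_tagged = KVEventBatch(ts=time.time(), events=[], attn_dp_rank=3)
publisher.publish(pre_tagged)  # rank 3 is preserved, not overwritten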
@@ -288,6 +310,39 @@ class ZmqEventPublisher(EventPublisher):
             # receiving payload is (-1, b"")
             self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))

+    @staticmethod
+    def offset_endpoint_port(
+        endpoint: Optional[str], data_parallel_rank: int
+    ) -> Optional[str]:
+        """Helper function to offset the port in an endpoint by
+        the data parallel rank.
+
+        Args:
+            endpoint: The endpoint string
+                (e.g., "tcp://*:5557" or "inproc://cache")
+            data_parallel_rank: The data parallel rank to offset by
+
+        Returns:
+            The endpoint with the port offset by data_parallel_rank,
+            or with a rank suffix appended for inproc endpoints
+        """
+        # Do nothing if input is None or data_parallel_rank is 0
+        if not endpoint or data_parallel_rank == 0:
+            return endpoint
+
+        if "inproc" in endpoint:
+            return f"{endpoint}_dp{data_parallel_rank}"
+
+        if "tcp" in endpoint:
+            if endpoint and ":" in endpoint:
+                # Get everything after the last colon (the port)
+                last_colon_idx = endpoint.rfind(":")
+                base_addr = endpoint[:last_colon_idx]
+                base_port = int(endpoint[last_colon_idx + 1 :])
+                new_port = base_port + data_parallel_rank
+                return f"{base_addr}:{new_port}"
+            return endpoint
+
+        raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
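A quick usage sketch of the helper; the outputs follow directly from the code above, so each DP rank ends up with its own TCP port or inproc name suffix:

ZmqEventPublisher.offset_endpoint_port("tcp://*:5557", 0)    # -> "tcp://*:5557" (rank 0 unchanged)
ZmqEventPublisher.offset_endpoint_port("tcp://*:5557", 1)    # -> "tcp://*:5558"
ZmqEventPublisher.offset_endpoint_port("inproc://cache", 2)  # -> "inproc://cache_dp2"
ZmqEventPublisher.offset_endpoint_port(None, 1)              # -> None
ZmqEventPublisher.offset_endpoint_port("ipc:///tmp/kv", 1)   # raises ValueError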
 class KVEventsConfig(BaseModel):
     """Configuration for KV event publishing."""

@@ -342,7 +397,7 @@ class EventPublisherFactory:
         cls._registry[name] = ctor

     @classmethod
-    def create(cls, config: Optional[str]) -> EventPublisher:
+    def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
         """Create publisher from a config mapping."""
         if not config:
             return NullEventPublisher()
@@ -354,4 +409,4 @@ class EventPublisherFactory:
             constructor = cls._registry[kind]
         except KeyError as exc:
             raise ValueError(f"Unknown event publisher '{kind}'") from exc
-        return constructor(**config_dict)
+        return constructor(attn_dp_rank=attn_dp_rank, **config_dict)
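A hypothetical call site, to show how the factory threads the rank through; the JSON config string matches the format the test below passes via --kv-events-config:

publisher = EventPublisherFactory.create(
    '{"publisher": "zmq", "topic": "kv-events"}',
    attn_dp_rank=1,
)
# With the ZMQ publisher's default endpoint "tcp://*:5557", rank 1
# ends up binding to "tcp://*:5558" via offset_endpoint_port above.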
@@ -571,7 +571,9 @@ class Scheduler(
     def init_kv_events(self, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
-            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+            self.kv_event_publisher = EventPublisherFactory.create(
+                kv_events_config, self.attn_dp_rank
+            )

     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
@@ -1988,7 +1990,7 @@ class Scheduler(
         self.cum_spec_accept_length = self.cum_spec_accept_count = 0
         for k, v in server_args_dict.items():
             global_server_args_dict[k] = v
-        logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
+        logger.info(f"Global server args updated! {global_server_args_dict=}")
         return SetInternalStateReqOutput(
             updated=True,
             server_args=global_server_args_dict,
@@ -48,6 +48,9 @@ class TestKvEvents(CustomTestCase):
                 32,
                 "--cuda-graph-max-bs",
                 2,
+                "--enable-dp-attention",
+                "--dp-size",
+                1,
             ],
         )
@@ -233,7 +236,6 @@ class TestKvEvents(CustomTestCase):
                     _, seq_bytes, payload = sub.recv_multipart()
                     event_batch = decoder.decode(payload)
                     for event in event_batch.events:
-                        print(f" - {event}")
                         events.append(event)

             for expected in expected_events:
@@ -242,6 +244,134 @@ class TestKvEvents(CustomTestCase):
         finally:
             kill_process_tree(process.pid)
+    def test_kv_events_attn_dp(self):
+        """Test that kv events are properly tagged with DP rank in attention DP mode"""
+        # Launch multiple subscribers for different DP ranks
+        decoder = Decoder(type=KVEventBatch)
+        context = zmq.Context()
+
+        # Subscribe to both DP rank endpoints
+        sub_dp0 = context.socket(zmq.SUB)
+        sub_dp0.connect("tcp://localhost:5557")  # DP rank 0
+        topic = "kv-events"
+        sub_dp0.setsockopt_string(zmq.SUBSCRIBE, topic)
+
+        sub_dp1 = context.socket(zmq.SUB)
+        sub_dp1.connect("tcp://localhost:5558")  # DP rank 1 (offset by rank)
+        sub_dp1.setsockopt_string(zmq.SUBSCRIBE, topic)
+
+        # Launch sglang server with DP attention enabled
+        process = popen_launch_server(
+            "silence09/DeepSeek-R1-Small-2layers",
+            DEFAULT_URL_FOR_TEST,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--kv-events-config",
+                '{"publisher": "zmq", "topic": "kv-events"}',
+                "--max-total-tokens",
+                64,
+                "--cuda-graph-max-bs",
+                4,
+                "--enable-dp-attention",
+                "--dp-size",
+                2,
+                "--tp-size",
+                2,
+            ],
+        )
+
+        try:
+            # Make requests to generate events
+            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
+            self.assertEqual(response.status_code, 200)
+
+            # Send multiple requests to trigger events from both DP ranks
+            for i in range(4):
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/generate",
+                    json={
+                        "text": f"Request {i}: The capital of country {i} is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 16,
+                        },
+                    },
+                )
+
+            # Collect events from both DP ranks
+            events_dp0 = []
+            events_dp1 = []
+            start = time.time()
+            max_wait_s = 10
+            min_events_per_rank = 3  # Expect at least a few events from each rank
+
+            while (time.time() - start) < max_wait_s and (
+                len(events_dp0) < min_events_per_rank
+                or len(events_dp1) < min_events_per_rank
+            ):
+                # Check DP rank 0
+                if sub_dp0.poll(timeout=100):  # 100 ms timeout
+                    _, seq_bytes, payload = sub_dp0.recv_multipart()
+                    event_batch = decoder.decode(payload)
+                    print(
+                        f"DP Rank 0 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
+                    )
+                    self.assertEqual(
+                        event_batch.attn_dp_rank,
+                        0,
+                        "DP rank 0 events should have attn_dp_rank=0",
+                    )
+                    for event in event_batch.events:
+                        print(f"  DP0 - {event}")
+                        events_dp0.append(event)
+
+                # Check DP rank 1
+                if sub_dp1.poll(timeout=100):  # 100 ms timeout
+                    _, seq_bytes, payload = sub_dp1.recv_multipart()
+                    event_batch = decoder.decode(payload)
+                    print(
+                        f"DP Rank 1 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
+                    )
+                    self.assertEqual(
+                        event_batch.attn_dp_rank,
+                        1,
+                        "DP rank 1 events should have attn_dp_rank=1",
+                    )
+                    for event in event_batch.events:
+                        print(f"  DP1 - {event}")
+                        events_dp1.append(event)
+
+            # Verify we got events from both DP ranks
+            print(f"Collected {len(events_dp0)} events from DP rank 0")
+            print(f"Collected {len(events_dp1)} events from DP rank 1")
+            self.assertGreaterEqual(
+                len(events_dp0),
+                min_events_per_rank,
+                f"Expected at least {min_events_per_rank} events from DP rank 0",
+            )
+            self.assertGreaterEqual(
+                len(events_dp1),
+                min_events_per_rank,
+                f"Expected at least {min_events_per_rank} events from DP rank 1",
+            )
+
+            # Verify event types are as expected
+            for events in [events_dp0, events_dp1]:
+                for event in events:
+                    self.assertIsInstance(
+                        event,
+                        (BlockStored, BlockRemoved, AllBlocksCleared),
+                        f"Event should be a KV cache event, got {type(event)}",
+                    )
+        finally:
+            sub_dp0.close()
+            sub_dp1.close()
+            context.term()
+            kill_process_tree(process.pid)
 if __name__ == "__main__":
     unittest.main()
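Beyond the test, the same subscription pattern works for any standalone consumer; here is a hedged sketch that reuses the Decoder/KVEventBatch types and the per-rank port offsets exactly as the test does (collect_events_by_rank is an illustrative helper, not part of this commit):

import zmq

def collect_events_by_rank(num_dp_ranks: int, base_port: int = 5557) -> dict[int, list]:
    """Subscribe to every per-rank endpoint and bucket events by DP rank."""
    decoder = Decoder(type=KVEventBatch)
    ctx = zmq.Context()
    subs = []
    for rank in range(num_dp_ranks):
        sub = ctx.socket(zmq.SUB)
        sub.connect(f"tcp://localhost:{base_port + rank}")  # rank-offset port
        sub.setsockopt_string(zmq.SUBSCRIBE, "kv-events")
        subs.append(sub)

    by_rank: dict[int, list] = {rank: [] for rank in range(num_dp_ranks)}
    for sub in subs:
        while sub.poll(timeout=100):  # drain whatever is currently queued
            _, _seq, payload = sub.recv_multipart()
            batch = decoder.decode(payload)
            # The publisher stamped the batch, so the payload itself
            # says which DP rank produced it.
            by_rank[batch.attn_dp_rank].extend(batch.events)
    return by_rank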