Unverified commit 0a4fc73b, authored by Byron Hsu, committed by GitHub

[PD] Fix failure abort (#6535)

parent a6970a17
python/sglang/srt/disaggregation/decode.py
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
    poll_and_all_reduce,
    prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
@@ -321,11 +322,15 @@ class DecodeTransferQueue:
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: torch.Tensor,
        scheduler: Scheduler,
        tree_cache: BasePrefixCache,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.metadata_buffers = metadata_buffers
        self.scheduler = scheduler
        self.tree_cache = tree_cache

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)
@@ -341,6 +346,14 @@ class DecodeTransferQueue:
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        # First, remove all failed requests from the queue
        for i, decode_req in enumerate(self.queue):
            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                indices_to_remove.add(i)

        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
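
This pre-pass relies on the failed request already carrying a FINISH_ABORT finished_reason (set elsewhere via prepare_abort); stream_output then surfaces the abort to the client before the queue entry is dropped. Below is a minimal, self-contained sketch of the same filter-and-stream pattern; Req, FINISH_ABORT, and stream_output are hypothetical stand-ins for the real sglang scheduler types.

# Minimal sketch of the abort pre-pass; Req, FINISH_ABORT, and stream_output
# are hypothetical stand-ins for the real sglang scheduler types.
from dataclasses import dataclass

class FINISH_ABORT:  # stand-in for sglang.srt.managers.schedule_batch.FINISH_ABORT
    pass

@dataclass
class Req:
    rid: str
    finished_reason: object = None
    return_logprob: bool = False

def stream_output(reqs, return_logprob):  # stand-in for Scheduler.stream_output
    for r in reqs:
        print(f"streamed abort for {r.rid} (logprobs={return_logprob})")

queue = [Req("a"), Req("b", finished_reason=FINISH_ABORT()), Req("c")]

indices_to_remove = set()
for i, req in enumerate(queue):
    if isinstance(req.finished_reason, FINISH_ABORT):
        stream_output([req], req.return_logprob)  # tell the client first
        indices_to_remove.add(i)

# drop aborted entries in one pass, preserving the order of survivors
queue = [r for i, r in enumerate(queue) if i not in indices_to_remove]
assert [r.rid for r in queue] == ["a", "c"]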
@@ -396,95 +409,6 @@ class DecodeTransferQueue:
        return transferred_reqs

class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populating metadata.
        Adapted from .prepare_for_extend().
        """
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0
        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to the schedule batch."""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)

class SchedulerDisaggregationDecodeMixin:
    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
...
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (new file)
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populating metadata.
        Adapted from .prepare_for_extend().
        """
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0
        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to the schedule batch."""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)
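
The heart of prepare_for_prebuilt_extend is the single-pass fill of out_cache_loc: each request's KV-cache slot indices are copied from its row in req_to_token into one flat tensor. A toy, CPU-only sketch of that gather follows; the 3x8 pool layout and the lengths are illustrative assumptions, not the real pool shape.

# Toy sketch of the out_cache_loc fill: concatenate each request's KV-cache
# slot indices into one flat tensor in a single pass. The req_to_token layout
# and sizes here are illustrative assumptions.
import torch

req_to_token = torch.arange(24, dtype=torch.int64).reshape(3, 8)  # 3 reqs, 8 slots each
extend_input_lens = [5, 3, 6]  # per-request number of tokens to extend

total_size = sum(extend_input_lens)
out_cache_loc = torch.empty(total_size, dtype=torch.int64)

offset = 0
for pool_idx, extend_len in enumerate(extend_input_lens):
    chunk = req_to_token[pool_idx][:extend_len]
    assert offset + extend_len <= total_size
    out_cache_loc[offset : offset + extend_len] = chunk
    offset += extend_len

print(out_cache_loc.tolist())
# [0, 1, 2, 3, 4, 8, 9, 10, 16, 17, 18, 19, 20, 21]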
python/sglang/srt/disaggregation/utils.py
from __future__ import annotations

import dataclasses
import os
import random
import warnings
from collections import deque
from enum import Enum
...
@@ -15,6 +17,9 @@ from sglang.srt.utils import get_ip
FakeBootstrapHost = "2.2.2.2"

# env var for testing failure, convert to float explicitly
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))


class DisaggregationMode(Enum):
    NULL = "null"
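
One subtlety worth noting: FAILURE_PROB is evaluated once at module import, so the environment variable must be set before sglang is imported (or exported before launching the server). A hedged usage sketch; the 0.05 probability is only an example value.

# FAILURE_PROB is read at import time, so set the env var first.
import os

os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"  # example value

from sglang.srt.disaggregation.utils import FAILURE_PROB

print(FAILURE_PROB)  # 0.05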
@@ -23,6 +28,15 @@ class DisaggregationMode(Enum):
def poll_and_all_reduce(pollers, gloo_group):
    # with probability FAILURE_PROB, report a failed poll to simulate failure
    if FAILURE_PROB > 0:
        from sglang.srt.disaggregation.base import KVPoll

        polls = [
            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
            for poller in pollers
        ]
    else:
        polls = [int(poller.poll()) for poller in pollers]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
...
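
The injected failures compose with the existing MIN all-reduce: assuming KVPoll.Failed is the smallest poll value (see sglang.srt.disaggregation.base.KVPoll for the real enum), any rank that reports Failed for a request forces every rank to agree on Failed after the reduction. A single-process sketch of that consensus; the enum values here are assumptions for illustration.

# Single-process sketch of the MIN all-reduce consensus on poll states.
# The enum ordering (Failed smallest) is an assumption for illustration.
import os

import torch
import torch.distributed as dist

FAILED, BOOTSTRAPPING, WAITING, TRANSFERRING, SUCCESS = range(5)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

polls = [SUCCESS, FAILED, TRANSFERRING]  # per-request states on this rank
t = torch.tensor(polls, dtype=torch.uint8, device="cpu")
dist.all_reduce(t, op=dist.ReduceOp.MIN)  # element-wise MIN across ranks
print(t.tolist())  # any rank reporting FAILED forces FAILED everywhere

dist.destroy_process_group()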
python/sglang/srt/managers/schedule_batch.py
@@ -48,7 +48,9 @@ from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
...
python/sglang/srt/managers/scheduler.py
@@ -582,6 +582,8 @@ class Scheduler(
            gloo_group=self.attn_tp_cpu_group,
            req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
            metadata_buffers=metadata_buffers,
            scheduler=self,
            tree_cache=self.tree_cache,
        )

        # The decode requests pending for pre-allocation
...
python/sglang/srt/mem_cache/chunk_cache.py
@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
    def cache_finished_req(self, req: Req):
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx,
            # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
        ]
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)
...
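
The max(..., 0) guard matters on the decode server, where a request can be aborted before producing any output ids: the old slice length len(origin_input_ids) + len(output_ids) - 1 would then free one KV slot too few, leaking it. A tiny arithmetic sketch of the two expressions; the 7-token prompt is illustrative.

# Sketch of the slice-length guard: with no output ids, the old expression
# frees one KV slot too few; the guarded one frees them all.
origin_input_ids = list(range(7))  # 7 prompt tokens (illustrative)

for output_ids in ([], [42, 43]):
    old = len(origin_input_ids) + len(output_ids) - 1
    new = len(origin_input_ids) + max(len(output_ids) - 1, 0)
    print(len(output_ids), old, new)
# 0 6 7   <- old leaks one slot when output_ids is empty
# 2 8 8   <- unchanged once output ids exist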