Unverified Commit d88ac9bc authored by Liangsheng Yin, committed by GitHub

[overlap-spec] Make plan stream an option (#11724)

parent ce11dd82
......@@ -221,6 +221,9 @@ class Envs:
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
+# Overlap Spec V2
+SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
# VLM
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
SGLANG_RESIZE_RESAMPLE = EnvStr("")
......
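The new flag follows the typed env-var pattern already used by `EnvInt` in this class. The actual `sglang.srt.environ` implementation is not part of this diff, so the following is only a hypothetical sketch of how an `EnvBool`-style wrapper could resolve `SGLANG_ENABLE_OVERLAP_PLAN_STREAM`; all internals here are assumptions.

```python
# Hypothetical sketch of an EnvBool-style wrapper; the real sglang
# implementation may differ. It binds the attribute name at class creation
# and parses the process environment lazily on .get().
import os


class EnvBool:
    def __init__(self, default: bool):
        self.default = default
        self.name = ""

    def __set_name__(self, owner, name):
        # Called when the class body is evaluated: records the env-var name,
        # e.g. "SGLANG_ENABLE_OVERLAP_PLAN_STREAM".
        self.name = name

    def get(self) -> bool:
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes", "on")


class Envs:
    SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)


envs = Envs()
print(envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get())  # False unless exported
```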
......@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
-forward_batch.extend_prefix_lens.sum().item(),
+sum(forward_batch.extend_prefix_lens_cpu),
dtype=torch.int64,
device=self.device,
)
......
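The one-line change above is about host/device synchronization: `tensor.sum().item()` launches a reduction on the GPU and then blocks the host until it finishes, while summing a CPU-side copy of the lengths never touches the device. A small standalone illustration, with made-up values:

```python
# Illustration of why the CPU-side sum is preferred on the overlap path.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
extend_prefix_lens = torch.tensor([3, 5, 2], device=device)
extend_prefix_lens_cpu = [3, 5, 2]  # host-side mirror kept by the batch

# Synchronizing: .item() waits for all queued GPU work on the stream.
n_blocking = extend_prefix_lens.sum().item()

# Non-blocking: plain Python arithmetic, no device round-trip.
n_host = sum(extend_prefix_lens_cpu)

assert n_blocking == n_host == 10
kv_indices = torch.empty(n_host, dtype=torch.int64, device=device)
```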
......@@ -404,6 +404,8 @@ class ForwardBatch:
if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens)
else:
+assert isinstance(batch.extend_seq_lens, list)
+assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
......
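The added asserts make the invariant explicit: at this point `extend_seq_lens` and `extend_prefix_lens` must be plain host-side lists, so the metadata tensor is constructed on the CPU and shipped with `non_blocking=True` rather than derived from device memory. A minimal sketch of the same pattern; the helper name is ours, not from the codebase:

```python
import torch


def lens_to_device(extend_seq_lens: list, device) -> torch.Tensor:
    # Fail fast if a device tensor leaked in where a host list is expected.
    assert isinstance(extend_seq_lens, list)
    cpu_t = torch.tensor(extend_seq_lens, dtype=torch.int32)
    # non_blocking=True lets the copy overlap with compute when the source
    # is pinned; for small metadata tensors it is harmless either way.
    return cpu_t.to(device, non_blocking=True)


device = "cuda" if torch.cuda.is_available() else "cpu"
print(lens_to_device([4, 4, 4], device))
```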
......@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
num_draft_tokens: int,
draft_model_runner: Any,
):
-seq_lens_cpu_backup = batch.seq_lens_cpu
+seq_lens_cpu_ = batch.seq_lens_cpu
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
batch.spec_info = self
......@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
batch.seq_lens_sum += extend_num_tokens
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
-batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
-batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
+batch.extend_prefix_lens = seq_lens_cpu_.tolist()
batch.extend_num_tokens = extend_num_tokens
batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
......
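To make the draft-extend bookkeeping above concrete, here is a worked example with illustrative numbers only: three requests and `num_draft_tokens = 4`.

```python
# Illustrative values; mirrors the assignments in the hunk above.
seq_lens_cpu_ = [10, 7, 12]          # per-request lengths before drafting
num_draft_tokens = 4

extend_num_tokens = len(seq_lens_cpu_) * num_draft_tokens      # 12
extend_seq_lens = [num_draft_tokens for _ in seq_lens_cpu_]    # [4, 4, 4]
extend_prefix_lens = list(seq_lens_cpu_)                       # [10, 7, 12]
new_seq_lens = [n + num_draft_tokens for n in seq_lens_cpu_]   # [14, 11, 16]
seq_lens_sum = sum(seq_lens_cpu_) + extend_num_tokens          # 29 + 12 = 41

print(extend_num_tokens, extend_seq_lens, extend_prefix_lens, new_seq_lens)
```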
+import contextlib
import logging
from typing import List, Optional
import torch
from torch.cuda import Stream as CudaStream
+from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.scheduler import GenerationBatchResult
......@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
)
self.tree_mask_mode = TreeMaskMode.FULL_MASK
-self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
# TODO(lsyin): potential bugs with a separate plan stream
-self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
+if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
+    self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
+    self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
+else:
+    self.plan_stream = None
+    self.plan_stream_ctx = contextlib.nullcontext()
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode():
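The gating above is the core of this change: with the flag off (the new default), `plan_stream_ctx` is a no-op context manager and all planning work stays on the current stream; with it on, planning is issued to a separate CUDA stream and may overlap the forward pass, at the risk noted in the TODO. The pattern in isolation, with simplified names standing in for the class attributes:

```python
import contextlib

import torch

# Stand-in for envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get().
use_plan_stream = False

if use_plan_stream and torch.cuda.is_available():
    plan_stream = torch.cuda.Stream()
    plan_stream_ctx = torch.cuda.stream(plan_stream)
else:
    plan_stream = None
    plan_stream_ctx = contextlib.nullcontext()

with plan_stream_ctx:
    # Kernels launched here target plan_stream when the flag is on,
    # and the current (default) stream otherwise: same call sites,
    # no branching at each use.
    pass
```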
......@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
batch: ModelWorkerBatch,
pre_draft_allocate_lens: torch.Tensor,
):
+# Since batch.seq_lens is allocated in another stream, we need
+# record_stream() to prevent PyTorch from freeing and reusing its GPU
+# memory while forward_stream is still running.
+batch.seq_lens.record_stream(torch.cuda.current_stream())
# Parse args
verify_input: EagleVerifyInput = batch.spec_info
seq_lens_backup = batch.seq_lens
bs = len(batch.seq_lens)
# Batch 1: Target verify
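The new comment is terse, so to spell out the allocator rule behind it: PyTorch's caching allocator may hand a tensor's memory to a new allocation as soon as the last Python reference dies, and by default it only fences that reuse against the stream the tensor was allocated on. `record_stream()` registers another stream as a consumer, delaying reuse until that stream's queued work completes. A standalone illustration, not code from this PR:

```python
import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        t = torch.empty(1 << 20, device="cuda")  # block owned by `side`
    out = t * 2  # consumed on the default stream

    # Tell the allocator the default stream also uses `t`, so its memory
    # is not recycled for a new tensor while the multiply is in flight.
    t.record_stream(torch.cuda.current_stream())
    del t
```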
......@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
accept_length,
accept_index,
) = verify_input.sample(batch, logits_output)
-new_seq_lens = seq_lens_backup + accept_length
+new_seq_lens = batch.seq_lens + accept_length
-verify_done = torch.cuda.Event()
-# Move the accepted tokens to the target KV cache locations
-batch.seq_lens = seq_lens_backup
-self.move_accepted_tokens_to_target_kvcache(
-    batch,
-    accept_index,
-    accept_length,
-)
-verify_done.record()
all_verified_id = predict[accept_index]
......@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states
-# Since seq_lens_backup's tensor is allocated in another stream, we
-# need record_stream() to prevent pytorch gc and reuse the gpu memory
-# while forward_stream is still running.
-seq_lens_backup.record_stream(torch.cuda.current_stream())
# Construct the return values
next_draft_input = EagleDraftInput(
topk_p=ret_topk_p,
......
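For reference, the `verify_done` event removed above follows the standard CUDA-event pattern: record a marker on the producing stream once the accepted tokens are in place, then have consumers wait on that exact point instead of synchronizing the whole device. A standalone sketch, not the PR's code:

```python
import torch

if torch.cuda.is_available():
    verify_done = torch.cuda.Event()

    x = torch.randn(1 << 20, device="cuda")
    y = torch.softmax(x, dim=-1)   # stands in for the verify work
    verify_done.record()           # marks completion of everything queued so far

    consumer = torch.cuda.Stream()
    consumer.wait_event(verify_done)   # consumer stream blocks at this marker
    with torch.cuda.stream(consumer):
        z = y * 2                      # safe: ordered after verify_done

    verify_done.synchronize()          # host-side wait, if the CPU needs the result
```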