"vscode:/vscode.git/clone" did not exist on "cadf5824e334225677e6376d837506404a299dcf"
Unverified commit d88ac9bc, authored by Liangsheng Yin and committed by GitHub
Browse files

[overlap-spec] Make plan stream an option (#11724)

parent ce11dd82
...@@ -221,6 +221,9 @@ class Envs: ...@@ -221,6 +221,9 @@ class Envs:
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096) SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256) SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
# Overlap Spec V2
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
# VLM # VLM
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28) SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
SGLANG_RESIZE_RESAMPLE = EnvStr("") SGLANG_RESIZE_RESAMPLE = EnvStr("")
......
...@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
) )
kv_indptr = kv_indptr[: bs + 1] kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty( kv_indices = torch.empty(
forward_batch.extend_prefix_lens.sum().item(), sum(forward_batch.extend_prefix_lens_cpu),
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )
......
...@@ -404,6 +404,8 @@ class ForwardBatch: ...@@ -404,6 +404,8 @@ class ForwardBatch:
if ret.positions is None: if ret.positions is None:
ret.positions = clamp_position(batch.seq_lens) ret.positions = clamp_position(batch.seq_lens)
else: else:
assert isinstance(batch.extend_seq_lens, list)
assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True) ).to(device, non_blocking=True)
......
...@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin: ...@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
num_draft_tokens: int, num_draft_tokens: int,
draft_model_runner: Any, draft_model_runner: Any,
): ):
seq_lens_cpu_backup = batch.seq_lens_cpu seq_lens_cpu_ = batch.seq_lens_cpu
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
batch.spec_info = self batch.spec_info = self
...@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin: ...@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
batch.seq_lens_sum += extend_num_tokens batch.seq_lens_sum += extend_num_tokens
batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))] batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
batch.extend_prefix_lens = seq_lens_cpu_backup.tolist() batch.extend_prefix_lens = seq_lens_cpu_.tolist()
batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
batch.extend_num_tokens = extend_num_tokens batch.extend_num_tokens = extend_num_tokens
batch.capture_hidden_mode = CaptureHiddenMode.FULL batch.capture_hidden_mode = CaptureHiddenMode.FULL
batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2 batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
......
import contextlib
import logging import logging
from typing import List, Optional from typing import List, Optional
import torch import torch
from torch.cuda import Stream as CudaStream from torch.cuda import Stream as CudaStream
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.scheduler import GenerationBatchResult
...@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
) )
self.tree_mask_mode = TreeMaskMode.FULL_MASK self.tree_mask_mode = TreeMaskMode.FULL_MASK
self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
# TODO(lsyin): potential bugs with a separate plan stream if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream) self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
else:
self.plan_stream = None
self.plan_stream_ctx = contextlib.nullcontext()
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
if model_worker_batch.forward_mode.is_decode(): if model_worker_batch.forward_mode.is_decode():
...@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
batch: ModelWorkerBatch, batch: ModelWorkerBatch,
pre_draft_allocate_lens: torch.Tensor, pre_draft_allocate_lens: torch.Tensor,
): ):
# batch.seq_lens was allocated on a different stream, so we call
# record_stream() to keep PyTorch from reclaiming and reusing its GPU
# memory while forward_stream is still running.
batch.seq_lens.record_stream(torch.cuda.current_stream())
# Parse args # Parse args
verify_input: EagleVerifyInput = batch.spec_info verify_input: EagleVerifyInput = batch.spec_info
seq_lens_backup = batch.seq_lens
bs = len(batch.seq_lens) bs = len(batch.seq_lens)
# Batch 1: Target verify # Batch 1: Target verify
...@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
accept_length, accept_length,
accept_index, accept_index,
) = verify_input.sample(batch, logits_output) ) = verify_input.sample(batch, logits_output)
new_seq_lens = seq_lens_backup + accept_length new_seq_lens = batch.seq_lens + accept_length
verify_done = torch.cuda.Event() verify_done = torch.cuda.Event()
# Move the accepted tokens to the target KV cache locations
batch.seq_lens = seq_lens_backup
self.move_accepted_tokens_to_target_kvcache(
batch,
accept_index,
accept_length,
)
verify_done.record() verify_done.record()
all_verified_id = predict[accept_index] all_verified_id = predict[accept_index]
...@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker): ...@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1) ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
ret_hidden_states = draft_logits_output.hidden_states ret_hidden_states = draft_logits_output.hidden_states
# Since seq_lens_backup's tensor is allocated in another stream, we
# need record_stream() to prevent pytorch gc and reuse the gpu memory
# while forward_stream is still running.
seq_lens_backup.record_stream(torch.cuda.current_stream())
# Construct the return values # Construct the return values
next_draft_input = EagleDraftInput( next_draft_input = EagleDraftInput(
topk_p=ret_topk_p, topk_p=ret_topk_p,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment