Unverified commit 10d60cd4, authored by u4lr451 and committed by GitHub
Browse files

feat: mtp support dp-attention (#6081)


Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
parent 8a10c4c3
......@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_world_size,
get_tp_group,
patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
......@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
with patch_tensor_parallel_group(tp_group):
yield
......@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
......@@ -302,17 +307,29 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
logits_output, next_token_ids, bid, seq_lens_cpu = (
self.forward_target_extend(batch)
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
return logits_output, next_token_ids, bid, 0, False
else:
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
need_forward, can_run_draft_extend_cuda_graph = (
self.check_forward_draft_extend_after_decode(batch)
)
if need_forward:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(batch)
self.forward_draft_extend_after_decode(
batch, can_run_draft_extend_cuda_graph
)
return (
logits_output,
verify_output.verified_id,
......@@ -320,22 +337,30 @@ class EAGLEWorker(TpModelWorker):
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids, _ = (
self.target_worker.forward_batch_generation(model_worker_batch)
)
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else:
logits_output, next_token_ids, bid, seq_lens_cpu = (
self.forward_target_extend(batch)
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
return logits_output, next_token_ids, bid, 0, False
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
    """Decide whether the post-decode draft-extend pass has to run.

    Returns:
        A ``(need_forward, can_run_draft_extend_cuda_graph)`` tuple.

        Without DP attention the decision is purely local: ``need_forward`` is
        true when ``batch.spec_info.verified_id`` exists and is non-empty, and
        the CUDA-graph flag is always ``True``.

        With DP attention the local flag is all-reduced over the CPU group of
        the TP group (default reduction, i.e. a sum), giving the number of
        ranks that have work. ``need_forward`` is true if at least one rank
        has work; the CUDA graph is only allowed when every rank has work.
    """
    verified_id = batch.spec_info.verified_id
    has_local_work = verified_id is not None and verified_id.shape[0] > 0

    # Guard clause: without DP attention there is nothing to synchronize.
    if not self.server_args.enable_dp_attention:
        return has_local_work, True

    # One-element CPU tensor carrying this rank's vote (0 or 1).
    votes = torch.tensor([has_local_work], dtype=torch.int64)
    torch.distributed.all_reduce(votes, group=get_tp_group().cpu_group)
    ranks_with_work = votes[0].item()

    any_rank_has_work = ranks_with_work > 0
    every_rank_has_work = ranks_with_work == get_tensor_model_parallel_world_size()
    return any_rank_has_work, every_rank_has_work
def forward_target_extend(
self, batch: ScheduleBatch
......@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.spec_num_draft_tokens = 1
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
......@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch.seq_lens_cpu,
)
def draft(self, batch: ScheduleBatch):
def _draft_preprocess_decode(self, batch: ScheduleBatch):
# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
......@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
batch.return_hidden_states = False
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
def _draft_preprocess_idle(self, batch: ScheduleBatch):
    """Give an idle batch a placeholder draft input.

    Overwrites ``batch.spec_info`` with the object returned by
    ``EagleDraftInput.create_idle_input``, built from this worker's device,
    model hidden size and top-k, with ``CaptureHiddenMode.LAST``.
    """
    idle_draft_input = EagleDraftInput.create_idle_input(
        device=self.device,
        hidden_size=self.model_config.hidden_size,
        topk=self.topk,
        capture_hidden_mode=CaptureHiddenMode.LAST,
    )
    batch.spec_info = idle_draft_input
def draft(self, batch: ScheduleBatch):
# Parse args
if batch.forward_mode.is_idle():
self._draft_preprocess_idle(batch)
else:
self._draft_preprocess_decode(batch)
spec_info = batch.spec_info
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.topk
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
forward_batch
)
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
if not forward_batch.forward_mode.is_idle():
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
if batch.forward_mode.is_idle():
return EagleVerifyInput.create_idle_input(
self.topk,
self.speculative_num_steps,
self.speculative_num_draft_tokens,
)
(
tree_mask,
......@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
)
return EagleVerifyInput(
......@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
batch.return_hidden_states = False
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.forward_mode = (
ForwardMode.TARGET_VERIFY
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
......@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
self.add_logprob_values(batch, res, logits_output)
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.forward_mode = (
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
)
batch.spec_info = res.draft_input
return logits_output, res, model_worker_batch, can_run_cuda_graph
......@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = 1
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
def forward_draft_extend_after_decode(
self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
):
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
# Prepare metadata
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
if batch.spec_info.verified_id is not None:
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
else:
batch = batch.copy()
batch.prepare_for_idle()
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=self.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
......@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
# Run
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
can_run_draft_extend_cuda_graph
and self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
)
if can_cuda_graph:
......@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
if not forward_batch.forward_mode.is_idle():
self.draft_model_runner.attn_backend.init_forward_metadata(
forward_batch
)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
......@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.forward_mode = (
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
)
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.accept_length = accept_length_backup
......
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_amd_ci,
popen_launch_server,
)
......@@ -65,5 +71,71 @@ class TestDPAttentionDP2TP2(CustomTestCase):
self.assertGreater(metrics["score"], 0.8)
class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
    """GSM8K check for EAGLE speculative decoding combined with DP attention.

    One server is launched for the whole class with ``--tp-size 2``,
    ``--dp-size 2``, ``--enable-dp-attention`` and an EAGLE draft model; the
    test then checks both answer accuracy and the average speculative accept
    length reported by the server.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by all tests in this class."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--trust-remote-code",
            "--disable-radix",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            "2",
            "--speculative-eagle-topk",
            "4",
            "--speculative-num-draft-tokens",
            "4",
            "--speculative-draft",
            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
            "--tp-size",
            "2",
            "--enable-dp-attention",
            "--dp-size",
            "2",
        ]
        # The memory-fraction override is only applied outside the AMD CI.
        if not is_in_amd_ci():
            other_args += ["--mem-frac", "0.7"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process and all of its children."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run few-shot GSM8K and check accuracy and accept length."""
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        # The accuracy line uses a plain format spec: combining a literal
        # "accuracy=" label with the "=" debug specifier would print
        # "accuracy=metrics['accuracy']=...".
        print(
            "###test_gsm8k (deepseek-v3 mtp + dp):\n"
            f"accuracy={metrics['accuracy']:.3f}\n"
            f"{avg_spec_accept_length=:.3f}\n"
        )
        self.assertGreater(avg_spec_accept_length, 2.5)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment