Unverified commit 10d60cd4 authored by u4lr451, committed by GitHub

feat: mtp support dp-attention (#6081)


Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
parent 8a10c4c3
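Context for the change: with --enable-dp-attention, each data-parallel rank schedules its own requests, so one rank can drain while its peers are still decoding. Because the draft and verify passes contain collectives, a drained rank cannot simply skip a step; it must join with an idle batch (ForwardMode.IDLE) and agree with the other ranks on whether the draft-extend pass runs at all. The sketch below illustrates only the participation requirement. It is a standalone toy, not sglang code: the function name and launch command are ours, and it assumes a gloo process group started via torchrun.

import torch
import torch.distributed as dist


def forward_step(num_local_tokens: int) -> torch.Tensor:
    # A rank with zero requests still enters the collective; otherwise the
    # ranks that do have work would block forever on this all_reduce.
    hidden = torch.zeros(1)
    if num_local_tokens > 0:
        hidden += num_local_tokens  # stand-in for real model computation
    dist.all_reduce(hidden)  # every rank participates, idle or not
    return hidden


if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 idle_demo.py
    dist.init_process_group("gloo")
    tokens = 4 if dist.get_rank() == 0 else 0  # pretend rank 1 has drained
    print(dist.get_rank(), forward_step(tokens))
    dist.destroy_process_group()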
@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download

-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
-from sglang.srt.layers.dp_attention import disable_dp_size
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+    with patch_tensor_parallel_group(tp_group):
         yield
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -302,17 +307,29 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.is_decode():
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
+                )
+            return logits_output, next_token_ids, bid, 0, False
+        else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
-
-            # If it is None, it means all requests are finished
-            if batch.spec_info.verified_id is not None:
+            need_forward, can_run_draft_extend_cuda_graph = (
+                self.check_forward_draft_extend_after_decode(batch)
+            )
+            if need_forward:
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(batch)
+                    self.forward_draft_extend_after_decode(
+                        batch, can_run_draft_extend_cuda_graph
+                    )
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -320,22 +337,30 @@ class EAGLEWorker(TpModelWorker):
                 sum(verify_output.accept_length_per_req_cpu),
                 can_run_cuda_graph,
             )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
-            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
-        else:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
-            )
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
-                self.forward_draft_extend(
-                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
-                )
-            return logits_output, next_token_ids, bid, 0, False
+
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward, True
+
+        global_need_forward = torch.tensor(
+            [local_need_forward],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        can_run_draft_extend_cuda_graph = (
+            global_need_forward_cnt == get_tensor_model_parallel_world_size()
+        )
+        return need_forward, can_run_draft_extend_cuda_graph

     def forward_target_extend(
         self, batch: ScheduleBatch
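The new check_forward_draft_extend_after_decode above is the coordination point: each rank votes on whether it still has verified tokens to extend, the votes are summed with an all_reduce over the TP CPU group, and the sum answers two questions at once. Any nonzero sum means someone must run the draft-extend pass (drained ranks join idly); the padded draft-extend CUDA graph is only used when every rank has real work. A minimal restatement of that decision rule, with made-up vote patterns (illustrative only, not sglang API):

def decide(votes):
    # cnt is what the all_reduce over the cpu_group computes.
    cnt = sum(votes)
    return cnt > 0, cnt == len(votes)  # (need_forward, can_use_cuda_graph)


for name, votes in {
    "all ranks busy": [True, True, True, True],
    "one rank drained": [True, True, True, False],
    "everything finished": [False, False, False, False],
}.items():
    need_forward, can_graph = decide(votes)
    print(f"{name}: need_forward={need_forward}, cuda_graph={can_graph}")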
@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch.seq_lens_cpu,
         )

-    def draft(self, batch: ScheduleBatch):
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False

         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
-            # Initialize attention backend
-            self.draft_attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)

-        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )

         (
             tree_mask,
@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.server_args.speculative_num_draft_tokens,
+            self.speculative_num_draft_tokens,
         )

         return EagleVerifyInput(
@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

         if batch.has_grammar:
@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
             self.add_logprob_values(batch, res, logits_output)

         # Prepare the batch for the next draft forwards.
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
         batch.spec_info = res.draft_input

         return logits_output, res, model_worker_batch, can_run_cuda_graph
@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

-    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+    def forward_draft_extend_after_decode(
+        self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
+    ):
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob

-        # Prepare metadata
-        if batch.spec_info.verified_id is not None:
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
             batch.spec_info.prepare_extend_after_decode(
                 batch,
                 self.speculative_num_steps,
             )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=self.model_config.hidden_size,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
         # Run
         can_cuda_graph = (
-            self.cuda_graph_runner_for_draft_extend
+            can_run_draft_extend_cuda_graph
+            and self.cuda_graph_runner_for_draft_extend
             and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
         )
         if can_cuda_graph:
@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-            self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
...
 import unittest
 from types import SimpleNamespace

+import requests
+
 from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
+    is_in_amd_ci,
     popen_launch_server,
 )
@@ -65,5 +71,71 @@ class TestDPAttentionDP2TP2(CustomTestCase):
         self.assertGreater(metrics["score"], 0.8)

+
+class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--trust-remote-code",
+            "--disable-radix",
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-num-steps",
+            "2",
+            "--speculative-eagle-topk",
+            "4",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--speculative-draft",
+            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
+            "--tp-size",
+            "2",
+            "--enable-dp-attention",
+            "--dp-size",
+            "2",
+        ]
+        if not is_in_amd_ci():
+            other_args += ["--mem-frac", "0.7"]
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(
+            f"###test_gsm8k (deepseek-v3 mtp + dp):\n"
+            f"accuracy={metrics['accuracy']:.3f}\n"
+            f"{avg_spec_accept_length=:.3f}\n"
+        )
+        self.assertGreater(avg_spec_accept_length, 2.5)
+
+
 if __name__ == "__main__":
     unittest.main()
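Outside the test harness, the same endpoint the test polls can be checked by hand once a server is launched with the flags above. The URL below assumes the default local port and is illustrative only:

import requests

base_url = "http://127.0.0.1:30000"  # adjust to your --port
info = requests.get(base_url + "/get_server_info").json()
accept_len = info["internal_states"][0]["avg_spec_accept_length"]
# With --speculative-num-steps 2, topk 4, and 4 draft tokens, the test
# above expects this to exceed 2.5.
print(f"avg_spec_accept_length={accept_len:.3f}")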