Unverified commit 750940ae authored by Rain H, committed by GitHub

Eagle3 DP attention for Qwen3 MoE (#12002)

parent 42f8ea40
......@@ -15,7 +15,7 @@
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional
import torch
......@@ -216,6 +216,28 @@ class LayerCommunicator:
get_global_server_args().speculative_algorithm
)
def prepare_attn_and_capture_last_layer_outputs(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
):
hidden_states, residual = self.prepare_attn(
hidden_states, residual, forward_batch
)
if captured_last_layer_outputs is not None:
gathered_last_layer_output = self._communicate_simple_fn(
hidden_states=residual,
forward_batch=forward_batch,
context=self._context,
)
if gathered_last_layer_output is residual:
# Clone to avoid the custom RMSNorm in-place operation mutating the original residual
gathered_last_layer_output = residual.clone()
captured_last_layer_outputs.append(gathered_last_layer_output)
return hidden_states, residual
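Why the identity check and clone above: when there is nothing to gather (attention-TP world of size 1), `_communicate_simple_fn` can return its input tensor unchanged, so the captured aux hidden state would alias `residual` and be clobbered by the custom RMSNorm's in-place write in the next layer. A minimal sketch of the hazard, with a hypothetical stand-in for the communicate function:

```python
import torch

def gather_or_passthrough(t: torch.Tensor, world_size: int) -> torch.Tensor:
    # Hypothetical stand-in for _communicate_simple_fn: identity when the
    # attention-TP world has a single rank, a real gather otherwise.
    return t if world_size == 1 else torch.cat([t] * world_size, dim=0)

residual = torch.ones(4, 8)
captured = gather_or_passthrough(residual, world_size=1)
if captured is residual:
    # Same storage: without this clone, the in-place norm below would also
    # rewrite the captured aux hidden state.
    captured = residual.clone()
residual.mul_(0.0)  # emulate the in-place RMSNorm on the live residual
assert captured.sum().item() == 32.0  # capture survives intact
```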
def prepare_attn(
self,
hidden_states: torch.Tensor,
......
......@@ -19,6 +19,7 @@ from sglang.srt.utils import add_prefix
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
import copy
from typing import Iterable, Optional, Tuple
import torch
......@@ -161,6 +162,10 @@ class LlamaModel(nn.Module):
if hidden_states.shape[-1] != embeds.shape[-1]:
hidden_states = self.fc(hidden_states)
# Idle batch (zero tokens on this DP rank): return early, keeping the (hidden_states, aux) output structure
if hidden_states.shape[0] == 0:
return hidden_states, [hidden_states]
residual = None
hidden_states, residual = self.midlayer(
positions,
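Context for the early return above: under DP attention a rank can be handed an idle batch with zero tokens while peer ranks do real work, and the draft model must still return the same `(hidden_states, aux_hidden_states)` structure so the caller's unpacking and downstream collectives line up. A tiny sketch of that contract (hidden size illustrative):

```python
import torch

hidden_states = torch.empty(0, 4096)  # zero tokens scheduled on this DP rank
if hidden_states.shape[0] == 0:
    # Same output structure as the non-idle path: a tensor plus an aux list.
    out, aux = hidden_states, [hidden_states]
assert out.shape[0] == 0 and len(aux) == 1
```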
......@@ -212,7 +217,12 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
prefix=add_prefix("lm_head", prefix),
)
-self.logits_processor = LogitsProcessor(config)
+config_ = copy.deepcopy(config)
+config_.vocab_size = (
+config_.draft_vocab_size
+)  # the draft logits processor has its own vocab size
+self.logits_processor = LogitsProcessor(config_)
self.capture_aux_hidden_states = True
self.hot_token_id = None
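The draft head scores only a reduced `draft_vocab_size` vocabulary, which is why the logits processor gets a deep-copied config with `vocab_size` overridden; `hot_token_id` (initialized just above) later maps draft-vocab rows back to target-vocabulary ids. A hedged sketch of that mapping, with illustrative sizes rather than sglang's exact code:

```python
import torch

target_vocab, draft_vocab = 151_936, 32_000  # illustrative sizes
hot_token_id = torch.randint(0, target_vocab, (draft_vocab,))

draft_logits = torch.randn(2, draft_vocab)   # [batch, draft_vocab]
topk = draft_logits.topk(4, dim=-1)
target_ids = hot_token_id[topk.indices]      # draft-vocab rows -> target ids
print(target_ids.shape)                      # torch.Size([2, 4])
```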
......
......@@ -473,10 +473,16 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
-hidden_states, residual = self.layer_communicator.prepare_attn(
-hidden_states, residual, forward_batch
+hidden_states, residual = (
+self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+hidden_states,
+residual,
+forward_batch,
+captured_last_layer_outputs=captured_last_layer_outputs,
+)
)
if hidden_states.shape[0] != 0:
......@@ -553,6 +559,11 @@ class Qwen2MoeModel(nn.Module):
# For EAGLE3 support
self.layers_to_capture = []
def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
self.layers_to_capture = layers_to_capture
for layer_id in self.layers_to_capture:
setattr(self.layers[layer_id], "_is_layer_to_capture", True)
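`set_eagle3_layers_to_capture` replaces the old per-iteration `i in self.layers_to_capture` membership test (removed further down) with a one-time flag on the chosen modules, so each layer decides locally whether to feed `aux_hidden_states`. A minimal sketch of the flag-on-module pattern:

```python
import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))
for layer_id in (2, 3, 5):  # layers_to_capture
    setattr(layers[layer_id], "_is_layer_to_capture", True)

tagged = [i for i, layer in enumerate(layers)
          if getattr(layer, "_is_layer_to_capture", False)]
assert tagged == [2, 3, 5]
```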
def forward(
self,
input_ids: torch.Tensor,
......@@ -585,12 +596,6 @@ class Qwen2MoeModel(nn.Module):
)
else:
for i in range(self.start_layer, self.end_layer):
-if i in self.layers_to_capture:
-aux_hidden_states.append(
-hidden_states + residual
-if residual is not None
-else hidden_states
-)
ctx = (
nullcontext()
if get_global_server_args().enable_piecewise_cuda_graph
......@@ -599,7 +604,15 @@ class Qwen2MoeModel(nn.Module):
with ctx:
layer = self.layers[i]
hidden_states, residual = layer(
-positions, hidden_states, forward_batch, residual
+positions,
+hidden_states,
+forward_batch,
+residual,
+captured_last_layer_outputs=(
+aux_hidden_states
+if getattr(layer, "_is_layer_to_capture", False)
+else None
+),
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
......@@ -830,13 +843,15 @@ class Qwen2MoeForCausalLM(nn.Module):
self.capture_aux_hidden_states = True
if layer_ids is None:
num_layers = self.config.num_hidden_layers
-self.model.layers_to_capture = [
-2,
-num_layers // 2,
-num_layers - 3,
-]  # Specific layers for EAGLE3 support
+self.model.set_eagle3_layers_to_capture(
+[
+2,
+num_layers // 2,
+num_layers - 3,
+]
+)  # Specific layers for EAGLE3 support
else:
-self.model.layers_to_capture = [val + 1 for val in layer_ids]
+self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
EntryClass = Qwen2MoeForCausalLM
......@@ -537,10 +537,16 @@ class Qwen3MoeDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
-hidden_states, residual = self.layer_communicator.prepare_attn(
-hidden_states, residual, forward_batch
+hidden_states, residual = (
+self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+hidden_states,
+residual,
+forward_batch,
+captured_last_layer_outputs=captured_last_layer_outputs,
+)
)
if hidden_states.shape[0] != 0:
......@@ -772,13 +778,15 @@ class Qwen3MoeForCausalLM(nn.Module):
self.capture_aux_hidden_states = True
if layer_ids is None:
num_layers = self.config.num_hidden_layers
-self.model.layers_to_capture = [
-2,
-num_layers // 2,
-num_layers - 3,
-]  # Specific layers for EAGLE3 support
+self.model.set_eagle3_layers_to_capture(
+[
+2,
+num_layers // 2,
+num_layers - 3,
+]
+)  # Specific layers for EAGLE3 support
else:
-self.model.layers_to_capture = [val + 1 for val in layer_ids]
+self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
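Both MoE models now funnel the capture list through `set_eagle3_layers_to_capture`. A worked example of the defaults, assuming `num_hidden_layers = 48` (the Qwen3-30B-A3B case); the `val + 1` shift reflects that the capture now fires on the residual stream entering the tagged layer, which equals the output of layer `val`:

```python
num_layers = 48
print([2, num_layers // 2, num_layers - 3])  # [2, 24, 45]

# Explicit layer_ids are shifted by one: tagging layer val + 1 captures the
# output of layer val at the next layer's input.
layer_ids = [1, 23, 44]
print([val + 1 for val in layer_ids])        # [2, 24, 45]
```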
......@@ -822,7 +822,7 @@ class ServerArgs:
capture_bs = (
list(range(1, 9, 1))
+ list(range(10, 33, 2))
-+ list(range(40, 64, 4))
++ list(range(40, 65, 4))
+ list(range(72, 257, 8))
+ list(range(272, self.cuda_graph_max_bs + 1, 16))
)
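The one-character `capture_bs` change fixes an off-by-one: `range` excludes its stop value, so batch size 64 was never in the CUDA-graph capture list even though the test below launches with `--cuda-graph-max-bs 64`. A quick check:

```python
old = list(range(40, 64, 4))  # [40, 44, 48, 52, 56, 60]
new = list(range(40, 65, 4))  # [40, 44, 48, 52, 56, 60, 64]
assert 64 not in old and 64 in new
```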
......
......@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
import torch
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -117,7 +118,11 @@ class EAGLEWorker(TpModelWorker):
self.hot_token_id = None
# Init draft worker
-with empty_context():
+if server_args.enable_dp_attention and self.speculative_algorithm.is_eagle3():
+ctx = draft_tp_context(get_attention_tp_group())
+else:
+ctx = empty_context()
+with ctx:
super().__init__(
server_args=server_args,
gpu_id=gpu_id,
......
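With DP attention and EAGLE3, the draft worker is constructed under the attention-TP group rather than the full TP group. A hedged sketch of the idea behind `draft_tp_context` (the globals here are stand-ins, not sglang's mechanism):

```python
from contextlib import contextmanager

_current_tp_group = "full_tp_group"  # stand-in for process-group global state

@contextmanager
def draft_tp_context(group):
    # Temporarily present `group` as the model's TP group during draft init.
    global _current_tp_group
    prev, _current_tp_group = _current_tp_group, group
    try:
        yield
    finally:
        _current_tp_group = prev

with draft_tp_context("attention_tp_group"):
    assert _current_tp_group == "attention_tp_group"
assert _current_tp_group == "full_tp_group"
```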
......@@ -84,6 +84,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_AWQ_INT4 = (
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST = "Qwen/Qwen3-30B-A3B"
DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST = "Tengyunw/qwen3_30b_moe_eagle3"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
......
......@@ -158,6 +158,7 @@ suites = {
TestFile("test_load_weights_from_remote_instance.py", 72),
TestFile("test_patch_torch.py", 19),
TestFile("test_release_memory_occupation.py", 257),
TestFile("test_eagle_dp_attention.py", 200),
],
"per-commit-4-gpu": [
TestFile("models/test_qwen3_next_models.py", 291),
......
import unittest
from types import SimpleNamespace
import requests
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_amd_ci,
is_in_ci,
kill_process_tree,
popen_launch_server,
write_github_step_summary,
)
class TestEAGLE3EngineDPAttention(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--trust-remote-code",
"--speculative-algorithm",
"EAGLE3",
"--speculative-num-steps",
"6",
"--speculative-eagle-topk",
"10",
"--speculative-num-draft-tokens",
"32",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST,
"--tp-size",
"2",
"--dp-size",
"2",
"--enable-dp-attention",
"--enable-dp-lm-head",
"--moe-dense-tp-size",
"1",
"--attention-backend",
"fa3",
"--mem-fraction-static",
"0.75",
"--cuda-graph-max-bs",
"64",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
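This is the smallest interesting DP-attention topology. Assuming sglang's usual decomposition, where the attention-TP size is `tp_size // dp_size`, each of the two DP ranks runs attention at TP=1 while the MoE layers still use the full 2-way TP group:

```python
tp_size, dp_size = 2, 2
attn_tp_size = tp_size // dp_size  # each DP rank runs attention at TP=1
assert attn_tp_size == 1
```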
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_a_gsm8k(self):
"""Test GSM8K evaluation - append 'a' to run first alphabetically"""
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"{metrics=}")
server_info = requests.get(self.base_url + "/get_server_info")
server_data = server_info.json()
# Try to get avg_spec_accept_length
avg_spec_accept_length = None
if "internal_states" in server_data and len(server_data["internal_states"]) > 0:
internal_state = server_data["internal_states"][0]
if "avg_spec_accept_length" in internal_state:
avg_spec_accept_length = internal_state["avg_spec_accept_length"]
elif "spec_accept_length" in internal_state:
avg_spec_accept_length = internal_state["spec_accept_length"]
print(f"{avg_spec_accept_length=}")
if is_in_ci():
summary = f"### test_gsm8k (EAGLE3 DP Attention)\n" f'{metrics["accuracy"]=:.3f}\n'
if avg_spec_accept_length is not None:
# Guard the ":.2f" format spec: applying it to None raises TypeError
summary += f"{avg_spec_accept_length=:.2f}\n"
write_github_step_summary(summary)
self.assertGreater(metrics["accuracy"], 0.91)
if avg_spec_accept_length is not None:
self.assertGreater(avg_spec_accept_length, 2.5)
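The 2.5 threshold is the average number of tokens committed per target-model forward pass. A rough, purely illustrative speedup estimate (the drafting-overhead fraction below is hypothetical, not measured by this test):

```python
accept_length = 2.5
draft_overhead = 0.4  # hypothetical fractional cost of the draft passes
print(f"~{accept_length / (1 + draft_overhead):.2f}x vs. one token per forward")
```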
def test_bs_1_speed(self):
"""Test batch size 1 speed with EAGLE3 DP Attention"""
args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
acc_length, speed = send_one_prompt(args)
print(f"{acc_length=:.2f} {speed=:.2f}")
if is_in_ci():
write_github_step_summary(
f"### test_bs_1_speed (EAGLE3 DP Attention)\n"
f"{acc_length=:.2f}\n"
f"{speed=:.2f} token/s\n"
)
if is_in_amd_ci():
self.assertGreater(acc_length, 2.0)
else:
self.assertGreater(acc_length, 2.3)
if is_in_amd_ci():
self.assertGreater(speed, 10)
else:
self.assertGreater(speed, 40)
if __name__ == "__main__":
unittest.main()