"vscode:/vscode.git/clone" did not exist on "27ff762a60e04976d8586a9de4d53efb03bbcfcd"
Unverified Commit 750940ae authored by Rain H, committed by GitHub

Eagle3 DP attention for Qwen3 MoE (#12002)

parent 42f8ea40
@@ -15,7 +15,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import torch
@@ -216,6 +216,28 @@ class LayerCommunicator:
             get_global_server_args().speculative_algorithm
         )

+    def prepare_attn_and_capture_last_layer_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
+    ):
+        hidden_states, residual = self.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+        if captured_last_layer_outputs is not None:
+            gathered_last_layer_output = self._communicate_simple_fn(
+                hidden_states=residual,
+                forward_batch=forward_batch,
+                context=self._context,
+            )
+            if gathered_last_layer_output is residual:
+                # Clone to avoid modifying the original residual via the custom RMSNorm in-place operation.
+                gathered_last_layer_output = residual.clone()
+            captured_last_layer_outputs.append(gathered_last_layer_output)
+        return hidden_states, residual
+
     def prepare_attn(
         self,
         hidden_states: torch.Tensor,
...
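The clone guard above matters because `_communicate_simple_fn` can return its input unchanged when no gather is needed; the captured aux hidden state would then alias the live residual, which later layers normalize in place. A minimal sketch of the pattern, with a trivial stand-in for the communicate function (not the sglang API):

```python
import torch

def capture_residual(residual: torch.Tensor, captured: list, gather=lambda t: t):
    """Toy version of the capture path; `gather` stands in for the
    simple-communicate function (identity means no copy was made)."""
    out = gather(residual)
    if out is residual:
        # Same storage as `residual`: clone so a later in-place
        # normalization of the residual cannot corrupt the captured copy.
        out = residual.clone()
    captured.append(out)

captured = []
res = torch.randn(4, 8)
capture_residual(res, captured)
res.mul_(0.0)                       # simulate an in-place norm on the residual
assert captured[0].abs().sum() > 0  # the captured copy is unaffected
```

Appending into a caller-owned list keeps the communicator stateless; the model decides which layers capture.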
@@ -19,6 +19,7 @@ from sglang.srt.utils import add_prefix
 # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
 """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""

+import copy
 from typing import Iterable, Optional, Tuple

 import torch
@@ -161,6 +162,10 @@ class LlamaModel(nn.Module):
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)

+        # Idle batch (possible under DP attention): skip the layers entirely.
+        if hidden_states.shape[0] == 0:
+            return hidden_states, [hidden_states]
+
         residual = None
         hidden_states, residual = self.midlayer(
             positions,
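Under DP attention a rank can be handed an idle (zero-row) batch, so the draft model must return early with consistently shaped outputs. A toy illustration of the guard, under that assumption:

```python
import torch

def forward_stub(hidden_states: torch.Tensor):
    # Hypothetical mirror of the guard above: an idle DP rank sees a
    # zero-row batch and must still return (hidden_states, aux_list).
    if hidden_states.shape[0] == 0:
        return hidden_states, [hidden_states]
    return hidden_states * 2, [hidden_states]

h = torch.empty(0, 16)
out, aux = forward_stub(h)
assert out.shape == (0, 16) and len(aux) == 1
```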
@@ -212,7 +217,12 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
             prefix=add_prefix("lm_head", prefix),
         )

-        self.logits_processor = LogitsProcessor(config)
+        config_ = copy.deepcopy(config)
+        config_.vocab_size = (
+            config_.draft_vocab_size
+        )  # the draft logits processor has its own vocab size
+        self.logits_processor = LogitsProcessor(config_)

         self.capture_aux_hidden_states = True
         self.hot_token_id = None
...
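EAGLE3 draft heads may use a smaller vocabulary than the target model, so the draft's `LogitsProcessor` must be sized by `draft_vocab_size`. Deep-copying the config keeps the shared target config untouched; a small sketch with a hypothetical two-field config (real HF configs carry many more fields):

```python
import copy
from types import SimpleNamespace

config = SimpleNamespace(vocab_size=151_936, draft_vocab_size=32_000)

config_ = copy.deepcopy(config)
config_.vocab_size = config_.draft_vocab_size  # draft head logits use the draft vocab

assert config.vocab_size == 151_936  # shared config is untouched
assert config_.vocab_size == 32_000  # draft logits processor sees the draft vocab
```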
@@ -473,10 +473,16 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
-        )
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
+        )

         if hidden_states.shape[0] != 0:
@@ -553,6 +559,11 @@ class Qwen2MoeModel(nn.Module):
         # For EAGLE3 support
         self.layers_to_capture = []

+    def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
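Tagging the chosen layers with an attribute moves the capture decision onto the layer object itself, so the forward loop only does a `getattr` with a default instead of a membership test against a list. A self-contained sketch of the tagging scheme:

```python
import torch.nn as nn

layers = nn.ModuleList([nn.Identity() for _ in range(8)])

def set_layers_to_capture(layer_ids):
    # Mirrors set_eagle3_layers_to_capture: tag each chosen layer.
    for layer_id in layer_ids:
        setattr(layers[layer_id], "_is_layer_to_capture", True)

set_layers_to_capture([2, 4, 5])
flags = [getattr(layer, "_is_layer_to_capture", False) for layer in layers]
assert flags == [False, False, True, False, True, True, False, False]
```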
@@ -585,12 +596,6 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
-                if i in self.layers_to_capture:
-                    aux_hidden_states.append(
-                        hidden_states + residual
-                        if residual is not None
-                        else hidden_states
-                    )
                 ctx = (
                     nullcontext()
                     if get_global_server_args().enable_piecewise_cuda_graph
@@ -599,7 +604,15 @@ class Qwen2MoeModel(nn.Module):
                 with ctx:
                     layer = self.layers[i]
                     hidden_states, residual = layer(
-                        positions, hidden_states, forward_batch, residual
+                        positions,
+                        hidden_states,
+                        forward_batch,
+                        residual,
+                        captured_last_layer_outputs=(
+                            aux_hidden_states
+                            if getattr(layer, "_is_layer_to_capture", False)
+                            else None
+                        ),
                     )

         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
@@ -830,13 +843,15 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])


 EntryClass = Qwen2MoeForCausalLM
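The default capture set and the `val + 1` shift are easiest to see with concrete numbers; the depth below is an assumption for illustration:

```python
# Assuming a 48-layer model for illustration.
num_layers = 48
print([2, num_layers // 2, num_layers - 3])  # default capture set: [2, 24, 45]

# User-supplied layer ids are shifted by one: the capture now happens on the
# (gathered) input of layer val + 1, i.e. the output of layer val.
layer_ids = [1, 23, 44]
print([val + 1 for val in layer_ids])  # [2, 24, 45]
```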
...
@@ -537,10 +537,16 @@ class Qwen3MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
-        )
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
+        )

         if hidden_states.shape[0] != 0:
@@ -772,13 +778,15 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -822,7 +822,7 @@ class ServerArgs:
             capture_bs = (
                 list(range(1, 9, 1))
                 + list(range(10, 33, 2))
-                + list(range(40, 64, 4))
+                + list(range(40, 65, 4))
                 + list(range(72, 257, 8))
                 + list(range(272, self.cuda_graph_max_bs + 1, 16))
             )
...
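This one-character change fixes an off-by-one: Python ranges exclude the end point, so batch size 64 never received a CUDA graph even though the new test launches with `--cuda-graph-max-bs 64`:

```python
old = list(range(40, 64, 4))  # [40, 44, 48, 52, 56, 60]; 64 excluded (end is exclusive)
new = list(range(40, 65, 4))  # [40, 44, 48, 52, 56, 60, 64]
assert 64 not in old and 64 in new
```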
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -117,7 +118,11 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None

         # Init draft worker
-        with empty_context():
+        if server_args.enable_dp_attention and self.speculative_algorithm.is_eagle3():
+            ctx = draft_tp_context(get_attention_tp_group())
+        else:
+            ctx = empty_context()
+        with ctx:
             super().__init__(
                 server_args=server_args,
                 gpu_id=gpu_id,
...
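With DP attention enabled, the EAGLE3 draft worker is now constructed inside the attention-TP group's context instead of the default empty one. A minimal sketch of the select-then-enter pattern; `draft_tp_context` below is a stand-in for the real helper, which patches the model-parallel group during initialization:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def draft_tp_context(tp_group):
    # Stand-in: the real helper temporarily patches the model-parallel
    # group so the draft model is built against `tp_group`.
    print(f"draft worker init inside group: {tp_group}")
    yield

enable_dp_attention, is_eagle3 = True, True  # assumed server/algorithm state
ctx = (
    draft_tp_context("attention_tp_group")
    if (enable_dp_attention and is_eagle3)
    else nullcontext()
)
with ctx:
    pass  # super().__init__(...) runs here in the real worker
```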
@@ -84,6 +84,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_AWQ_INT4 = (
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST = "Qwen/Qwen3-30B-A3B"
+DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST = "Tengyunw/qwen3_30b_moe_eagle3"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
 DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
...
@@ -158,6 +158,7 @@ suites = {
         TestFile("test_load_weights_from_remote_instance.py", 72),
         TestFile("test_patch_torch.py", 19),
         TestFile("test_release_memory_occupation.py", 257),
+        TestFile("test_eagle_dp_attention.py", 200),
     ],
     "per-commit-4-gpu": [
         TestFile("models/test_qwen3_next_models.py", 291),
...
import unittest
from types import SimpleNamespace

import requests

from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_amd_ci,
    is_in_ci,
    kill_process_tree,
    popen_launch_server,
    write_github_step_summary,
)


class TestEAGLE3EngineDPAttention(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = [
            "--trust-remote-code",
            "--speculative-algorithm",
            "EAGLE3",
            "--speculative-num-steps",
            "6",
            "--speculative-eagle-topk",
            "10",
            "--speculative-num-draft-tokens",
            "32",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST,
            "--tp-size",
            "2",
            "--dp-size",
            "2",
            "--enable-dp-attention",
            "--enable-dp-lm-head",
            "--moe-dense-tp-size",
            "1",
            "--attention-backend",
            "fa3",
            "--mem-fraction-static",
            "0.75",
            "--cuda-graph-max-bs",
            "64",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_a_gsm8k(self):
        """Test GSM8K evaluation; the 'a' prefix makes it run first alphabetically."""
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        server_info = requests.get(self.base_url + "/get_server_info")
        server_data = server_info.json()

        # Try to get avg_spec_accept_length
        avg_spec_accept_length = None
        if "internal_states" in server_data and len(server_data["internal_states"]) > 0:
            internal_state = server_data["internal_states"][0]
            if "avg_spec_accept_length" in internal_state:
                avg_spec_accept_length = internal_state["avg_spec_accept_length"]
            elif "spec_accept_length" in internal_state:
                avg_spec_accept_length = internal_state["spec_accept_length"]
        print(f"{avg_spec_accept_length=}")

        if is_in_ci():
            # Guard the format spec: avg_spec_accept_length may still be None here.
            summary = (
                f"### test_gsm8k (EAGLE3 DP Attention)\n"
                f'{metrics["accuracy"]=:.3f}\n'
            )
            if avg_spec_accept_length is not None:
                summary += f"{avg_spec_accept_length=:.2f}\n"
            write_github_step_summary(summary)

        self.assertGreater(metrics["accuracy"], 0.91)
        if avg_spec_accept_length is not None:
            self.assertGreater(avg_spec_accept_length, 2.5)

    def test_bs_1_speed(self):
        """Test batch size 1 speed with EAGLE3 DP Attention"""
        args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
        acc_length, speed = send_one_prompt(args)
        print(f"{acc_length=:.2f} {speed=:.2f}")

        if is_in_ci():
            write_github_step_summary(
                f"### test_bs_1_speed (EAGLE3 DP Attention)\n"
                f"{acc_length=:.2f}\n"
                f"{speed=:.2f} token/s\n"
            )

        if is_in_amd_ci():
            self.assertGreater(acc_length, 2.0)
        else:
            self.assertGreater(acc_length, 2.3)

        if is_in_amd_ci():
            self.assertGreater(speed, 10)
        else:
            self.assertGreater(speed, 40)


if __name__ == "__main__":
    unittest.main()
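For manual reproduction outside the test harness, the flags above translate into a standalone launch roughly like the following (a sketch; adjust hardware-dependent values such as `--mem-fraction-static` for your setup):

```python
# Hypothetical standalone launch mirroring the flags the test passes to
# popen_launch_server; model paths come from the constants above.
import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "Qwen/Qwen3-30B-A3B",
    "--speculative-algorithm", "EAGLE3",
    "--speculative-draft-model-path", "Tengyunw/qwen3_30b_moe_eagle3",
    "--speculative-num-steps", "6",
    "--speculative-eagle-topk", "10",
    "--speculative-num-draft-tokens", "32",
    "--tp-size", "2", "--dp-size", "2",
    "--enable-dp-attention", "--enable-dp-lm-head",
    "--moe-dense-tp-size", "1",
    "--attention-backend", "fa3",
    "--mem-fraction-static", "0.75",
    "--cuda-graph-max-bs", "64",
    "--trust-remote-code",
])
```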