Unverified Commit 78c43d88 authored by JensenFire's avatar JensenFire Committed by GitHub
Browse files

[Feature] Initial eagle3 support for Deepseek-like models (#12319)

parent 7e28c67d
......@@ -21,7 +21,7 @@ import concurrent.futures
import logging
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -2841,6 +2841,7 @@ class DeepseekV2Model(nn.Module):
self.embed_tokens.embedding_dim,
)
)
self.layers_to_capture = []
def get_input_embeddings(self) -> nn.Module:
    """Return the model's token-embedding module (``self.embed_tokens``).

    NOTE(review): the previous annotation said ``torch.Tensor``, but the
    returned object is the embedding layer itself — ``embedding_dim`` is
    read from it elsewhere in this class, which a plain tensor does not
    have — so it is annotated as a module instead.
    """
    return self.embed_tokens
......@@ -2897,9 +2898,11 @@ class DeepseekV2Model(nn.Module):
normal_end_layer = self.first_k_dense_replace
elif self.first_k_dense_replace < normal_start_layer:
normal_end_layer = normal_start_layer = 0
aux_hidden_states = []
for i in range(normal_start_layer, normal_end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
hidden_states, residual = layer(
positions,
......@@ -2937,7 +2940,9 @@ class DeepseekV2Model(nn.Module):
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class DeepseekV2ForCausalLM(nn.Module):
......@@ -2991,6 +2996,7 @@ class DeepseekV2ForCausalLM(nn.Module):
if isinstance(layer.mlp, DeepseekV2MoE)
}
)
self.capture_aux_hidden_states = False
@property
def routed_experts_weights_of_layer(self):
......@@ -3044,10 +3050,13 @@ class DeepseekV2ForCausalLM(nn.Module):
hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
)
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
)
else:
return hidden_states
......@@ -3755,6 +3764,20 @@ class DeepseekV2ForCausalLM(nn.Module):
num_groups=config.n_group,
)
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
    """Select which decoder layers' hidden states are captured for EAGLE3.

    Args:
        layer_ids: Layer indices whose input hidden states should be
            captured as auxiliary states for the EAGLE3 draft model.
            When ``None``, a default early/middle/late triple is used.

    Only the last pipeline-parallel rank computes logits, so all other
    ranks ignore this call.
    """
    if not self.pp_group.is_last_rank:
        return
    self.capture_aux_hidden_states = True
    if layer_ids is None:
        n_layers = self.config.num_hidden_layers
        # Default EAGLE3 taps: one early, one middle, one late layer.
        self.model.layers_to_capture = [2, n_layers // 2, n_layers - 3]
    else:
        # Shift each index by one: in sglang, layer i's aux hidden state
        # is taken from the output of layer i-1 (i.e. layer i's input).
        self.model.layers_to_capture = [layer_id + 1 for layer_id in layer_ids]
# Module-level registration: bind the per-backend attention handlers
# (defined elsewhere in this file) under their backend-name keys so the
# registry can dispatch to them by name.
AttentionBackendRegistry.register("ascend", handle_attention_ascend)
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment