Unverified Commit 78c43d88 authored by JensenFire, committed by GitHub

[Feature] Initial eagle3 support for Deepseek-like models (#12319)

parent 7e28c67d
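This change teaches DeepseekV2Model to record "aux hidden states" (the inputs of selected decoder layers) during the forward pass and to return them alongside the final hidden states; EAGLE3-style speculative decoding consumes these as features for the draft model. A minimal sketch of the capture pattern implemented in the diff below, written as a self-contained toy (the names, shapes, and Linear stand-in layers are illustrative, not sglang's API):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    # Toy stand-in for DeepseekV2Model. As in sglang, each layer keeps a
    # deferred residual, so the true input to layer i is hidden_states + residual.
    def __init__(self, num_layers: int = 8, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.layers_to_capture: list[int] = []  # e.g. [2, 4, 6]

    def forward(self, hidden_states: torch.Tensor):
        residual = torch.zeros_like(hidden_states)
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                # The true input to layer i (= output of layer i-1) is the
                # deferred sum hidden_states + residual, as in the diff below.
                aux_hidden_states.append(hidden_states + residual)
            residual = hidden_states + residual  # fold in the pending residual
            hidden_states = layer(residual)      # toy layer "delta"
        if len(aux_hidden_states) == 0:
            return hidden_states                 # non-EAGLE3 path is unchanged
        return hidden_states, aux_hidden_states

Returning a bare tensor when nothing is captured keeps the default path untouched, which is why DeepseekV2ForCausalLM.forward below unpacks the tuple only when capture_aux_hidden_states is set.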
@@ -21,7 +21,7 @@ import concurrent.futures
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
@@ -2841,6 +2841,7 @@ class DeepseekV2Model(nn.Module):
                 self.embed_tokens.embedding_dim,
             )
         )
+        self.layers_to_capture = []
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -2897,9 +2898,11 @@
             normal_end_layer = self.first_k_dense_replace
         elif self.first_k_dense_replace < normal_start_layer:
             normal_end_layer = normal_start_layer = 0
+        aux_hidden_states = []
         for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,
@@ -2937,7 +2940,9 @@
             hidden_states = self.norm(hidden_states)
         else:
             hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+        return hidden_states, aux_hidden_states
 
 
 class DeepseekV2ForCausalLM(nn.Module):
@@ -2991,6 +2996,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if isinstance(layer.mlp, DeepseekV2MoE)
             }
         )
+        self.capture_aux_hidden_states = False
 
     @property
     def routed_experts_weights_of_layer(self):
@@ -3044,10 +3050,13 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -3755,6 +3764,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_groups=config.n_group,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # We add 1 here because in sglang the i-th layer takes the output
+            # of the (i-1)-th layer as its aux hidden state.
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
...
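For reference, a hedged usage sketch of the new hook; the real call site (sglang's EAGLE3 speculative-decoding worker) is outside this diff, and `model` here is a hypothetical handle to an initialized DeepseekV2ForCausalLM:

# Hypothetical driver code, assuming `model` is a DeepseekV2ForCausalLM.
model.set_eagle3_layers_to_capture()  # default: pre-layer inputs at [2, L // 2, L - 3]
# e.g. with L = 61 decoder layers (DeepSeek-V3): captures at [2, 30, 58]

# Explicit ids are shifted by one so that the captured tensor is the
# *output* of each requested layer:
model.set_eagle3_layers_to_capture(layer_ids=[1, 29, 57])  # -> [2, 30, 58]
# Afterwards DeepseekV2Model.forward returns (hidden_states, aux_hidden_states)
# on the last pipeline rank, and logits_processor receives the aux states.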