Unverified Commit 9fafa62d authored by Ke Bao, committed by GitHub

Share target model embed and head weights for nextn (#4033)

parent 146ac8df
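
In short: the NextN (MTP) draft module no longer keeps its own copies of the token embedding and output head; it reuses the target model's tensors, so duplicate weights are not loaded. A minimal sketch of the sharing mechanism, using placeholder classes (TargetModel/DraftModel are illustrative, not the real sglang classes):

```python
import torch.nn as nn

class TargetModel(nn.Module):
    # Stand-in for the target model: owns the real embedding and lm_head.
    def __init__(self, vocab=16, hidden=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def get_embed_and_head(self):
        return self.embed_tokens.weight, self.lm_head.weight

class DraftModel(nn.Module):
    # Stand-in for the NextN draft model: starts with throwaway copies.
    def __init__(self, vocab=16, hidden=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def set_embed_and_head(self, embed, head):
        # Drop the draft's own parameters and point at the target's tensors,
        # so both models read from the same storage.
        del self.embed_tokens.weight
        del self.lm_head.weight
        self.embed_tokens.weight = embed
        self.lm_head.weight = head

target, draft = TargetModel(), DraftModel()
draft.set_embed_and_head(*target.get_embed_and_head())
assert draft.lm_head.weight.data_ptr() == target.lm_head.weight.data_ptr()
```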
@@ -280,7 +280,8 @@ class ForwardBatch:
         ).to(device, non_blocking=True)
         if (
             model_runner.server_args.attention_backend != "torch_native"
-            and model_runner.server_args.speculative_algorithm != "NEXTN"
+            # TODO: Fix triton kernel illegal memory access for EAGLE
+            and model_runner.server_args.speculative_algorithm != "EAGLE"
         ):
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
...
@@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(config, quant_config)
         if global_server_args_dict["enable_dp_attention"]:
-            self.model.shared_head.head = ReplicatedLinear(
+            self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
-            self.model.shared_head.head = ParallelLMHead(
+            self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
@@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.shared_head.head, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
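
Note that the draft's output projection is now exposed as self.lm_head, the same attribute name used by DeepseekV2ForCausalLM, so the set_embed_and_head helper added further below can rebind the head on the NextN draft model just as it does on the target model.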
@@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         nextn_layer_prefix = "model.layers.0"
         nextn_spec_weight_names = [
-            "shared_head.head",
             "shared_head.norm",
             "eh_proj",
-            "embed_tokens",
             "enorm",
             "hnorm",
         ]
@@ -180,17 +178,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         for name, loaded_weight in weights:
             if not name.startswith(nextn_layer_prefix):
                 continue
-            else:
-                is_decoder = True
-                # For nextn specific weights
-                for weight_name in nextn_spec_weight_names:
-                    if weight_name in name:
-                        name = name.replace(nextn_layer_prefix, "model")
-                        is_decoder = False
-                        break
-                # For decoder layer weights
-                if is_decoder:
-                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            # Use shared head and embed weights from target model
+            if "shared_head.head" in name or "embed_tokens" in name:
+                continue
+
+            is_decoder = True
+            # For nextn specific weights
+            for weight_name in nextn_spec_weight_names:
+                if weight_name in name:
+                    name = name.replace(nextn_layer_prefix, "model")
+                    is_decoder = False
+                    break
+            # For decoder layer weights
+            if is_decoder:
+                name = name.replace(nextn_layer_prefix, "model.decoder")

             if "rotary_emb.inv_freq" in name:
                 continue
...
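
Since the head and embedding are taken from the target model at runtime, load_weights now skips those checkpoint entries. A rough, self-contained illustration of the resulting name filtering (the example names are made up; remap is a hypothetical helper that mirrors the branch logic above):

```python
nextn_layer_prefix = "model.layers.0"
nextn_spec_weight_names = ["shared_head.norm", "eh_proj", "enorm", "hnorm"]

def remap(name):
    # Returns the destination parameter name, or None if the entry is skipped.
    if not name.startswith(nextn_layer_prefix):
        return None
    if "shared_head.head" in name or "embed_tokens" in name:
        return None  # provided by the target model via set_embed_and_head
    for w in nextn_spec_weight_names:
        if w in name:
            return name.replace(nextn_layer_prefix, "model")
    return name.replace(nextn_layer_prefix, "model.decoder")

print(remap("model.layers.0.embed_tokens.weight"))     # None (shared with target)
print(remap("model.layers.0.eh_proj.weight"))          # model.eh_proj.weight
print(remap("model.layers.0.self_attn.q_proj.weight")) # model.decoder.self_attn.q_proj.weight
```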
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
             if is_hip_:
                 self_attn.w_scale *= 2.0

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
...
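
The del statements drop the existing parameter entries before the target's tensors are attached in their place; the subsequent empty_cache()/synchronize() calls presumably exist to return the memory of the discarded copies to the allocator before serving continues.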
@@ -270,10 +270,11 @@ class ServerArgs:
         )

         # Speculative Decoding
-        if (
-            self.speculative_algorithm == "EAGLE"
-            or self.speculative_algorithm == "NEXTN"
-        ):
+        if self.speculative_algorithm == "NEXTN":
+            # NEXTN shares the same implementation of EAGLE
+            self.speculative_algorithm = "EAGLE"
+
+        if self.speculative_algorithm == "EAGLE":
             self.disable_overlap_schedule = True
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
...
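
With this normalization, a server launched with --speculative-algorithm NEXTN is internally treated as EAGLE from this point on, so the downstream adjustments (overlap schedule, prefill batching, CUDA graph padding) only need to handle the EAGLE case.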
@@ -83,23 +83,16 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args

         # Share the embedding and lm_head
-        if not self.speculative_algorithm.is_nextn():
-            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-            if server_args.speculative_token_map is not None:
-                head = head.clone()
-                self.hot_token_id = torch.tensor(
-                    self.hot_token_id, dtype=torch.int32, device=head.device
-                )
-                head.data = head.data[self.hot_token_id]
-            else:
-                self.hot_token_id = None
-            self.model_runner.model.set_embed_and_head(embed, head)
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if server_args.speculative_token_map is not None:
+            head = head.clone()
+            self.hot_token_id = torch.tensor(
+                self.hot_token_id, dtype=torch.int32, device=head.device
+            )
+            head.data = head.data[self.hot_token_id]
         else:
-            if server_args.speculative_token_map is not None:
-                raise NotImplementedError(
-                    "NEXTN does not support speculative-token-map now"
-                )
             self.hot_token_id = None
+        self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

         # Create multi-step attn backends and cuda graph runners
...
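
When a speculative token map is provided, the shared head is cloned and row-sliced to the hot vocabulary before being handed to the draft model, so only the target's head keeps the full vocabulary. A small sketch of that slicing with made-up shapes and token ids:

```python
import torch

vocab, hidden = 32, 8
head = torch.randn(vocab, hidden)   # full lm_head weight from the target model
hot_token_id = [1, 5, 7, 30]        # e.g. loaded from --speculative-token-map

head = head.clone()                  # clone so the target's full head is untouched
hot = torch.tensor(hot_token_id, dtype=torch.int32, device=head.device)
head.data = head.data[hot]           # keep only the rows for hot tokens

print(head.shape)  # torch.Size([4, 8]): the draft only scores the reduced vocabulary
```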
@@ -5,24 +5,16 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
-    # NEXTN spec decoding is for DeepSeek V3/R1
-    # currently it's implemented based on EAGLE
-    NEXTN = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
-
-    def is_nextn(self):
-        return self == SpeculativeAlgorithm.NEXTN
+        return self == SpeculativeAlgorithm.EAGLE

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
-            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
...
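
Because ServerArgs now rewrites "NEXTN" to "EAGLE" before the algorithm name reaches from_string, the NEXTN enum member, is_nextn(), and the extra name_map entry are no longer needed and are removed here.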
@@ -62,6 +62,8 @@ def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
                 continue

             for key in matching_keys:
+                if "embed_tokens" in key or "shared_head.head" in key:
+                    continue
                 new_key = key.replace(prefix, "model.layers.0")
                 params[new_key] = f.get_tensor(key)
...
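
Correspondingly, the NextN-layer export script no longer copies embed_tokens or shared_head.head into the exported draft checkpoint, since those weights are taken from the target model at load time.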