Unverified Commit 1a0a04da authored by Chen Ding's avatar Chen Ding Committed by GitHub
Browse files

[Perf] Optimize memory peak during EAGLE model loading. (#24585)


Signed-off-by: default avatarChen Ding <candy.dc@alibaba-inc.com>
parent 6d8246aa
...@@ -229,14 +229,15 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): ...@@ -229,14 +229,15 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def transform(inputs):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
return name, loaded_weight
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=None, skip_prefixes=None,
) )
loader.load_weights(map(transform, weights))
model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
...@@ -205,23 +205,21 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -205,23 +205,21 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None: torch.Tensor]]) -> None:
def transform(inputs):
name, loaded_weight = inputs
name, weight = self.permute_qk_weight_for_rotary(
name, loaded_weight)
if "lm_head" not in name:
name = "model." + name
return name, weight
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
# lm_head is tied with target model (Llama4ForCausalLM) # lm_head is tied with target model (Llama4ForCausalLM)
skip_prefixes=(["lm_head."]), skip_prefixes=(["lm_head."]),
) )
loader.load_weights(map(transform, weights))
model_weights = {}
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
def get_input_embeddings( def get_input_embeddings(
self, self,
......
...@@ -158,14 +158,15 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): ...@@ -158,14 +158,15 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
return self.model(input_ids, positions, hidden_states) return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def transform(inputs):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
return name, loaded_weight
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=None, skip_prefixes=None,
) )
loader.load_weights(map(transform, weights))
model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment