Unverified Commit 1a0a04da authored by Chen Ding's avatar Chen Ding Committed by GitHub
Browse files

[Perf] Optimize memory peak during EAGLE model loading. (#24585)


Signed-off-by: default avatarChen Ding <candy.dc@alibaba-inc.com>
parent 6d8246aa
......@@ -229,14 +229,15 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def transform(inputs):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
return name, loaded_weight
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
)
model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
loader.load_weights(map(transform, weights))
......@@ -205,23 +205,21 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
def transform(inputs):
name, loaded_weight = inputs
name, weight = self.permute_qk_weight_for_rotary(
name, loaded_weight)
if "lm_head" not in name:
name = "model." + name
return name, weight
loader = AutoWeightsLoader(
self,
# lm_head is tied with target model (Llama4ForCausalLM)
skip_prefixes=(["lm_head."]),
)
model_weights = {}
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
loader.load_weights(map(transform, weights))
def get_input_embeddings(
self,
......
......@@ -158,14 +158,15 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def transform(inputs):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
return name, loaded_weight
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
)
model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
loader.load_weights(map(transform, weights))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment