Unverified commit 6e215118, authored by Lifu Huang, committed by GitHub

Fix incorrect default get_hidden_dim logic (#8987)

parent a47baff1
@@ -92,11 +92,30 @@ def get_hidden_dim(
     Please implement the function in the model class if it is not.
     You can reference this function in llama.py.
     """
-    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-        return config.hidden_size, config.hidden_size
-    elif module_name in ["kv_proj"]:
-        return config.hidden_size, config.hidden_size // (
-            config.num_attention_heads // config.num_key_value_heads
+    head_dim = getattr(
+        config, "head_dim", config.hidden_size // config.num_attention_heads
+    )
+
+    # TODO: the special handling of qkv will be addressed in #8940.
+    if module_name == "qkv_proj":
+        return (
+            config.hidden_size,
+            None,  # qkv_proj is only used in LoRA A
+        )
+    elif module_name == "kv_proj":
+        return (
+            None,  # kv_proj is only used in LoRA B
+            head_dim * config.num_key_value_heads,
+        )
+    elif module_name == "q_proj":
+        return (
+            None,  # q_proj is only used in LoRA B
+            head_dim * config.num_attention_heads,
+        )
+    elif module_name == "o_proj":
+        return (
+            head_dim * config.num_attention_heads,
+            config.hidden_size,
         )
     elif module_name == "gate_up_proj":
         return config.hidden_size, config.intermediate_size
...
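The corrected default above now derives `head_dim` from the config (falling back to `hidden_size // num_attention_heads` when the config does not define it) and sizes the attention projections from it, instead of assuming every projection is `hidden_size x hidden_size`. A minimal sketch of what this changes, using a made-up GQA-style config where `head_dim` is explicit (all numbers below are illustrative, not taken from any real model):

```python
from types import SimpleNamespace

# Hypothetical GQA-style config: head_dim is set explicitly and differs from
# hidden_size // num_attention_heads, which is exactly the case the old
# default mishandled.
config = SimpleNamespace(
    hidden_size=2048,
    num_attention_heads=16,
    num_key_value_heads=4,
    head_dim=256,  # 16 * 256 = 4096 != hidden_size
)

head_dim = getattr(
    config, "head_dim", config.hidden_size // config.num_attention_heads
)

# Corrected shapes from the new default:
print(head_dim * config.num_attention_heads)  # 4096 (q_proj output / o_proj input)
print(head_dim * config.num_key_value_heads)  # 1024 (kv_proj output)
print(config.hidden_size)                     # 2048 (o_proj output)

# The old default would have returned hidden_size (2048) for q_proj/o_proj and
# hidden_size // (num_attention_heads // num_key_value_heads) = 2048 // 4 = 512
# for kv_proj -- both wrong for a config like this.
```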
@@ -432,40 +432,6 @@ class Gemma2ForCausalLM(nn.Module):
 
         return result
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
...
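Gemma2's per-model override is deleted in this hunk. It already computed shapes from `config.head_dim`, so after the fix to the shared default it is redundant. A quick sanity check that the removed override and the new default agree, under an assumed Gemma2-like config (the numbers are illustrative, not read from a real checkpoint):

```python
from types import SimpleNamespace

# Illustrative Gemma2-like config: head_dim is explicit, and
# head_dim * num_attention_heads does not equal hidden_size.
cfg = SimpleNamespace(
    hidden_size=2304, num_attention_heads=8, num_key_value_heads=4, head_dim=256
)

head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

# Removed Gemma2 override vs. the new shared default, q_proj LoRA B dimension:
assert cfg.head_dim * cfg.num_attention_heads == head_dim * cfg.num_attention_heads == 2048
# kv_proj LoRA B dimension:
assert cfg.head_dim * cfg.num_key_value_heads == head_dim * cfg.num_key_value_heads == 1024
# Both give 2048 / 1024 even though hidden_size is 2304 -- which is why the
# old hidden_size-based default was wrong and this override used to exist.
```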
@@ -501,20 +501,26 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
 
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
+        # TODO: the special handling of qkv will be addressed in #8940.
+        if module_name == "qkv_proj":
             return (
                 self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
+                None,  # qkv_proj is only used in LoRA A
             )
-        elif module_name in ["o_proj"]:
+        elif module_name == "kv_proj":
+            return (
+                None,  # kv_proj is only used in LoRA B
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "q_proj":
             return (
+                None,  # q_proj is only used in LoRA B
                 self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
             )
-        elif module_name in ["kv_proj"]:
+        elif module_name in ["o_proj"]:
             return (
+                self.config.head_dim * self.config.num_attention_heads,
                 self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
             )
         elif module_name == "gate_up_proj":
             assert len(set(self.config.intermediate_size)) == 1, (
...
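Gemma3n keeps an override because its `intermediate_size` is a list (hence the assertion for `gate_up_proj`), but its attention branches are rewritten to match the new default, with `None` marking the dimension that is never consulted for that module. Going only by the inline comments, here is a hedged sketch of how a caller might read these `(input_dim, output_dim)` pairs; it illustrates the convention and is not the actual sglang LoRA-manager code:

```python
from typing import Optional, Tuple

def lora_a_input_dim(dims: Tuple[Optional[int], Optional[int]]) -> int:
    # LoRA A maps input_dim -> rank, so only the first element matters.
    # Used for fused lookups such as qkv_proj.
    in_dim, _ = dims
    assert in_dim is not None
    return in_dim

def lora_b_output_dim(dims: Tuple[Optional[int], Optional[int]]) -> int:
    # LoRA B maps rank -> output_dim, so only the second element matters.
    # Used for the split q_proj / kv_proj lookups.
    _, out_dim = dims
    assert out_dim is not None
    return out_dim
```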
@@ -363,31 +363,6 @@ class GraniteForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
...
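The Granite override removed here (and the Llama and TorchNativeLlama ones below) was essentially a copy of the old default, including the `hidden_size // (num_attention_heads // num_key_value_heads)` formula for `kv_proj`, so deleting it lets these models pick up the corrected shapes from the shared implementation. The two kv formulas only coincide when `head_dim == hidden_size // num_attention_heads`; a small check under assumed configs (values are illustrative):

```python
from types import SimpleNamespace

def old_kv_dim(c):
    # Formula from the removed overrides / old default.
    return c.hidden_size // (c.num_attention_heads // c.num_key_value_heads)

def new_kv_dim(c):
    head_dim = getattr(c, "head_dim", c.hidden_size // c.num_attention_heads)
    return head_dim * c.num_key_value_heads

# Config without an explicit head_dim: the formulas agree.
plain = SimpleNamespace(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
assert old_kv_dim(plain) == new_kv_dim(plain) == 1024

# Config with an explicit, non-derived head_dim: only the new formula is right.
wide = SimpleNamespace(
    hidden_size=2048, num_attention_heads=16, num_key_value_heads=4, head_dim=256
)
assert old_kv_dim(wide) == 512
assert new_kv_dim(wide) == 1024
```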
@@ -532,31 +532,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
...
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,
...
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
...
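Each deleted `get_module_name` body was the same literal mapping (q/k/v to `qkv_proj`, gate/up to `gate_up_proj`), which suggests these models now defer to one shared default elsewhere. For reference, the removed logic behaves as below; this is a sketch of what was deleted, not a new sglang API:

```python
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

def get_module_name(name: str) -> str:
    # Map a per-weight module name to the fused module it is loaded into;
    # anything not in the mapping (e.g. o_proj, down_proj) is returned as-is.
    return params_mapping.get(name, name)

assert get_module_name("k_proj") == "qkv_proj"
assert get_module_name("o_proj") == "o_proj"
```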