Unverified Commit 0d810cfb authored by JDKWangGuan's avatar JDKWangGuan Committed by GitHub

Fix KeyError handling for non-existing key in state_dict.pop() (#898)

Update the handling of missing keys in `state_dict.pop()`.
Changed `state_dict.pop(f"h.{d}.attn.bias")` to `state_dict.pop(f"h.{d}.attn.bias", None)` so that a checkpoint without this key no longer raises `KeyError`.
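
For context, `dict.pop` with a second argument returns that default instead of raising when the key is absent. A minimal standalone sketch of the pattern (illustrative keys and values, not taken from the commit):
```python
state_dict = {"h.0.attn.c_attn.weight": "..."}  # no "h.0.attn.bias" entry

# Without a default, pop raises for a missing key:
# state_dict.pop("h.0.attn.bias")  # KeyError: 'h.0.attn.bias'

# With a default, pop returns it instead of raising:
print(state_dict.pop("h.0.attn.bias", None))  # None
```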


The following code reproduces the issue:
```python
from transformers import AutoTokenizer, GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# >>> transformers.__version__
# '4.38.2'

model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
# ... model fine-tuning happens here ...
# dump the fine-tuned model
model.save_pretrained(output_model_path)

# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias' without this fix
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias' without this fix

```
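
The key is likely missing because recent transformers releases register GPT-2's causal-mask buffer `attn.bias` with `persistent=False`, so `save_pretrained` omits it from the checkpoint; this is an inference about upstream behavior, not something stated in this commit. A quick way to check:
```python
# Hedged check: on recent transformers (e.g. 4.38.x), the causal-mask buffer
# is assumed to be non-persistent, so it should not appear in state_dict().
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")
print([k for k in model.state_dict() if k.endswith("attn.bias")])  # expected: []
```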
parent 6a2a16e9
```diff
@@ -966,7 +966,7 @@ def remap_state_dict_hf_gpt2(state_dict, config):
     # Attention
     for d in range(config.num_hidden_layers):
-        state_dict.pop(f"h.{d}.attn.bias")  # We don't store this bias
+        state_dict.pop(f"h.{d}.attn.bias", None)  # We don't store this bias
         Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
         state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
         Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
```