JDKWangGuan authored
Update handling of KeyError in `state_dict.pop()` for non-existent keys

Changed `state_dict.pop(f"h.{d}.attn.bias")` to `state_dict.pop(f"h.{d}.attn.bias", None)` so that a checkpoint without this key no longer raises a `KeyError`. The following code reproduces the issue:

```python
from transformers import AutoTokenizer, GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# >>> transformers.__version__
# '4.38.2'

model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)

# ... model fine-tuning here ...

# dump the fine-tuned model
model.save_pretrained(output_model_path)

# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'
```

0d810cfb
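A minimal sketch of why the default argument matters, using plain dicts in place of a real PyTorch state dict. `drop_attn_bias` is a hypothetical helper written only for illustration; it is not the flash_attn function that contains this line, and the key values are placeholders.

```python
def drop_attn_bias(state_dict, n_layer):
    """Hypothetical helper: remove per-layer attn.bias entries if present."""
    for d in range(n_layer):
        # With a default, pop() returns None for a missing key instead of raising KeyError
        state_dict.pop(f"h.{d}.attn.bias", None)
    return state_dict


# Checkpoint that still contains the attention bias buffer for layer 0
with_bias = {"h.0.attn.bias": "mask", "h.0.attn.c_attn.weight": "w"}
# Checkpoint saved without that key (the case that used to fail)
without_bias = {"h.0.attn.c_attn.weight": "w"}

print(sorted(drop_attn_bias(with_bias, 1)))     # ['h.0.attn.c_attn.weight']
print(sorted(drop_attn_bias(without_bias, 1)))  # ['h.0.attn.c_attn.weight']

# Old behaviour: pop() without a default raises on a missing key
try:
    {"h.0.attn.c_attn.weight": "w"}.pop("h.0.attn.bias")
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'h.0.attn.bias'
```

Both checkpoints end up with the same keys, so loading proceeds either way; only the no-default `pop()` reproduces the original error.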