    Fix KeyError handling for non-existing key in state_dict.pop() (#898) · 0d810cfb
    JDKWangGuan authored
Update handling of KeyError in state_dict.pop() for non-existing keys.
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) so that loading a checkpoint that lacks these keys no longer raises a KeyError.
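The fix relies on `dict.pop`'s optional default argument. A minimal sketch of the behavior (key names taken from the error below; the surrounding checkpoint-loading logic is omitted):

```python
# A checkpoint saved by transformers' GPT2Model does not contain the
# h.{d}.attn.bias buffers, so the state dict is missing those keys.
state_dict = {"h.0.attn.weight": 1.0}

# Before the fix: pop() without a default raises KeyError on a missing key.
try:
    state_dict.pop("h.0.attn.bias")
except KeyError as e:
    print(f"KeyError: {e}")  # KeyError: 'h.0.attn.bias'

# After the fix: pop() with a default returns None instead of raising.
value = state_dict.pop("h.0.attn.bias", None)
print(value)  # None
```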
    
    
The following code reproduces the issue:
```python
from transformers import GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel
    
    # >>> transformers.__version__
    # '4.38.2'
    
    model_path = 'gpt2'
    output_model_path = 'gpt2_model'
    config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
    model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
    '''
    model fine-tuning here
    '''
    # dump the fine-tuned model
    model.save_pretrained(output_model_path)
    
# load the fine-tuned checkpoint with flash_attn's GPT implementations
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias'
    
    ```
Changed file: gpt.py