Unverified commit adb2503e authored by GeneZC, committed by GitHub

Fix stuff related to the causal_mask in CodeGen. (#21527)

* Fix stuff related to the causal_mask in CodeGen.

1. Line 613: change `_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]` to `_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]` so that a CodeGen checkpoint that does not store the `causal_mask` buffers loads without spurious missing-key warnings (a regex-matching sketch follows the diff below).
2. Line 152: change `causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]` to `causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].bool()` to avoid the potential warning `UserWarning: where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.` (a minimal sketch of this masking pattern follows the commit message).

* Revert the .bool()

Revert the `.bool()` change and leave it to a future PR.
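
Below is a minimal, self-contained sketch of the masking pattern that point 2 refers to. It assumes a `uint8` lower-triangular buffer (as an older checkpoint might provide); the names `causal_mask_buffer`, `attn_weights`, `mask_value` and the sizes are illustrative, not the exact CodeGen internals. It only shows why casting the sliced mask to `bool` silences the `torch.where` deprecation warning.

```python
import torch

# Illustrative sketch only -- not the exact CodeGen attention code.
# A lower-triangular causal mask stored as uint8, sliced to the current
# query/key window and applied with torch.where.
max_positions = 8
causal_mask_buffer = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.uint8)
).view(1, 1, max_positions, max_positions)

query_length, key_length = 4, 6
attn_weights = torch.randn(1, 1, query_length, key_length)
mask_value = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)

# The slice mirrors the line referenced in point 2 of the commit message.
causal_mask = causal_mask_buffer[:, :, key_length - query_length : key_length, :key_length]

# Passing a uint8 condition to torch.where triggers the deprecation UserWarning;
# casting the slice to bool keeps the same result without the warning.
attn_weights = torch.where(causal_mask.bool(), attn_weights, mask_value)
print(attn_weights.shape)  # torch.Size([1, 1, 4, 6])
```

Note that the second commit reverts the `.bool()` call itself and leaves the warning to a later PR, so the cast above is shown only to explain the reasoning behind point 2.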
parent 5b72b341
@@ -610,7 +610,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
     CODEGEN_START_DOCSTRING,
 )
 class CodeGenForCausalLM(CodeGenPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]
 
     def __init__(self, config):
         super().__init__(config)
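
For context on point 1, here is a hedged sketch of how missing-key patterns like those in `_keys_to_ignore_on_load_missing` are typically filtered with `re.search` during loading. The key names and the loop below are hypothetical simplifications, not the actual transformers loading code.

```python
import re

# Simplified illustration of pattern-based filtering of missing keys.
keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]

# Hypothetical missing keys reported when loading a CodeGen checkpoint that
# does not store the causal_mask buffers (the buffers are recreated at init).
missing_keys = [
    "transformer.h.0.attn.causal_mask",
    "transformer.h.1.attn.causal_mask",
    "transformer.h.0.mlp.some_new_weight",  # hypothetical, would still be reported
]

for pattern in keys_to_ignore_on_load_missing:
    missing_keys = [key for key in missing_keys if re.search(pattern, key) is None]

print(missing_keys)  # ['transformer.h.0.mlp.some_new_weight']
```

With the old patterns (`masked_bias`, `bias`), the `causal_mask` keys would not be filtered out, which is why loading such a checkpoint previously reported them as missing.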