"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6f4424bb086d3d090855862be5aff64eb8ed7101"
Unverified commit 2ab76c2c, authored by Yun Dai, committed by GitHub
Browse files

[Falcon] Set `use_cache=False` before creating `presents` which relies on `use_cache` (#26328)

* Set `presents=None` when `use_cache` is set to `False` for activation (gradient) checkpointing

* Update modeling_falcon.py

* Fix `black` formatting
parent 253f9a3f
......@@ -1094,6 +1094,12 @@ class FalconModel(FalconPreTrainedModel):
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
......@@ -1137,11 +1143,6 @@ class FalconModel(FalconPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment