Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
73de5108
Unverified
Commit
73de5108
authored
Dec 14, 2023
by
Younes Belkada
Committed by
GitHub
Dec 14, 2023
Browse files
[`core` / `modeling`] Fix training bug with PEFT + GC (#28031)
fix training bug
parent
2788f8d8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
35 additions
and
35 deletions
+35
-35
src/transformers/models/deprecated/open_llama/modeling_open_llama.py
...rmers/models/deprecated/open_llama/modeling_open_llama.py
+7
-7
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+7
-7
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+7
-7
src/transformers/models/persimmon/modeling_persimmon.py
src/transformers/models/persimmon/modeling_persimmon.py
+7
-7
src/transformers/models/phi/modeling_phi.py
src/transformers/models/phi/modeling_phi.py
+7
-7
No files found.
src/transformers/models/deprecated/open_llama/modeling_open_llama.py
View file @
73de5108
...
@@ -578,6 +578,13 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
...
@@ -578,6 +578,13 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
seq_length_with_past
=
seq_length
seq_length_with_past
=
seq_length
past_key_values_length
=
0
past_key_values_length
=
0
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
if
past_key_values
is
not
None
:
if
past_key_values
is
not
None
:
past_key_values_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
past_key_values_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
seq_length_with_past
=
seq_length_with_past
+
past_key_values_length
seq_length_with_past
=
seq_length_with_past
+
past_key_values_length
...
@@ -608,13 +615,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
...
@@ -608,13 +615,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
# decoder layers
# decoder layers
all_hidden_states
=
()
if
output_hidden_states
else
None
all_hidden_states
=
()
if
output_hidden_states
else
None
all_self_attns
=
()
if
output_attentions
else
None
all_self_attns
=
()
if
output_attentions
else
None
...
...
src/transformers/models/llama/modeling_llama.py
View file @
73de5108
...
@@ -1000,6 +1000,13 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1000,6 +1000,13 @@ class LlamaModel(LlamaPreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
past_key_values_length
=
0
past_key_values_length
=
0
if
use_cache
:
if
use_cache
:
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
...
@@ -1038,13 +1045,6 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1038,13 +1045,6 @@ class LlamaModel(LlamaPreTrainedModel):
# embed positions
# embed positions
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
# decoder layers
# decoder layers
all_hidden_states
=
()
if
output_hidden_states
else
None
all_hidden_states
=
()
if
output_hidden_states
else
None
all_self_attns
=
()
if
output_attentions
else
None
all_self_attns
=
()
if
output_attentions
else
None
...
...
src/transformers/models/mistral/modeling_mistral.py
View file @
73de5108
...
@@ -855,6 +855,13 @@ class MistralModel(MistralPreTrainedModel):
...
@@ -855,6 +855,13 @@ class MistralModel(MistralPreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
past_key_values_length
=
0
past_key_values_length
=
0
if
use_cache
:
if
use_cache
:
...
@@ -899,13 +906,6 @@ class MistralModel(MistralPreTrainedModel):
...
@@ -899,13 +906,6 @@ class MistralModel(MistralPreTrainedModel):
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
# decoder layers
# decoder layers
all_hidden_states
=
()
if
output_hidden_states
else
None
all_hidden_states
=
()
if
output_hidden_states
else
None
all_self_attns
=
()
if
output_attentions
else
None
all_self_attns
=
()
if
output_attentions
else
None
...
...
src/transformers/models/persimmon/modeling_persimmon.py
View file @
73de5108
...
@@ -608,6 +608,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
...
@@ -608,6 +608,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
seq_length_with_past
=
seq_length
seq_length_with_past
=
seq_length
past_key_values_length
=
0
past_key_values_length
=
0
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
if
use_cache
:
if
use_cache
:
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
if
use_legacy_cache
:
if
use_legacy_cache
:
...
@@ -635,13 +642,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
...
@@ -635,13 +642,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
# decoder layers
# decoder layers
all_hidden_states
=
()
if
output_hidden_states
else
None
all_hidden_states
=
()
if
output_hidden_states
else
None
all_self_attns
=
()
if
output_attentions
else
None
all_self_attns
=
()
if
output_attentions
else
None
...
...
src/transformers/models/phi/modeling_phi.py
View file @
73de5108
...
@@ -860,6 +860,13 @@ class PhiModel(PhiPreTrainedModel):
...
@@ -860,6 +860,13 @@ class PhiModel(PhiPreTrainedModel):
past_key_values_length
=
0
past_key_values_length
=
0
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
if
use_cache
:
if
use_cache
:
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
if
use_legacy_cache
:
if
use_legacy_cache
:
...
@@ -890,13 +897,6 @@ class PhiModel(PhiPreTrainedModel):
...
@@ -890,13 +897,6 @@ class PhiModel(PhiPreTrainedModel):
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
if
self
.
gradient_checkpointing
and
self
.
training
:
if
use_cache
:
logger
.
warning_once
(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache
=
False
# decoder layers
# decoder layers
all_hidden_states
=
()
if
output_hidden_states
else
None
all_hidden_states
=
()
if
output_hidden_states
else
None
all_self_attns
=
()
if
output_attentions
else
None
all_self_attns
=
()
if
output_attentions
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment