chenpangpang / transformers · Commits · 73de5108
Unverified commit 73de5108, authored Dec 14, 2023 by Younes Belkada; committed by GitHub, Dec 14, 2023
[`core` / `modeling`] Fix training bug with PEFT + GC (#28031)
fix training bug
parent 2788f8d8
Showing 5 changed files with 35 additions and 35 deletions (+35 -35)
src/transformers/models/deprecated/open_llama/modeling_open_llama.py  +7 -7
src/transformers/models/llama/modeling_llama.py  +7 -7
src/transformers/models/mistral/modeling_mistral.py  +7 -7
src/transformers/models/persimmon/modeling_persimmon.py  +7 -7
src/transformers/models/phi/modeling_phi.py  +7 -7
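All five files receive the same change: the guard that downgrades `use_cache=True` to `use_cache=False` when gradient checkpointing is active during training is moved from just before the decoder-layer loop to the top of the forward pass, ahead of any past-key-value handling. The toy class below is a self-contained sketch of that ordering (a hypothetical ToyDecoder, not transformers code); with the guard running first, no key/value cache is created or converted while training under checkpointing.

# Minimal, self-contained sketch of the ordering this commit enforces.
# ToyDecoder is hypothetical; only the control flow mirrors the patched forward().
import warnings


class ToyDecoder:
    def __init__(self, gradient_checkpointing: bool = True, training: bool = True):
        self.gradient_checkpointing = gradient_checkpointing
        self.training = training

    def forward(self, past_key_values=None, use_cache: bool = True):
        # (1) After this commit the guard runs first: checkpointed training
        #     forces use_cache to False before anything touches the cache.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                warnings.warn(
                    "`use_cache=True` is incompatible with gradient checkpointing. "
                    "Setting `use_cache=False`..."
                )
                use_cache = False

        # (2) Cache bookkeeping now sees the corrected flag, so no cache object
        #     is built or returned during checkpointed training.
        next_cache = [] if use_cache else None
        return next_cache


print(ToyDecoder().forward())  # None: no cache while training with checkpointing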
src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -578,6 +578,13 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -608,13 +615,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
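One detail worth noting before the remaining files: the deprecated open_llama model above still reads the cached length from the legacy tuple format (`past_key_values[0][0].shape[2]`), while the llama, mistral, persimmon and phi hunks below gate their cache handling on the newer `Cache` class (`isinstance(past_key_values, Cache)`). A standalone illustration of the legacy layout follows, with made-up shapes, assuming the usual (batch, num_heads, past_seq_len, head_dim) convention.

# Illustration of the legacy tuple cache indexed as past_key_values[0][0].shape[2].
# Shapes are made up; past_key_values[layer] = (key_states, value_states).
import torch

batch, num_heads, past_seq_len, head_dim = 1, 8, 5, 64
past_key_values = tuple(
    (
        torch.zeros(batch, num_heads, past_seq_len, head_dim),
        torch.zeros(batch, num_heads, past_seq_len, head_dim),
    )
    for _ in range(2)  # two decoder layers, arbitrary
)

past_key_values_length = past_key_values[0][0].shape[2]
print(past_key_values_length)  # 5 cached positions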
src/transformers/models/llama/modeling_llama.py
@@ -1000,6 +1000,13 @@ class LlamaModel(LlamaPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         past_key_values_length = 0
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
@@ -1038,13 +1045,6 @@ class LlamaModel(LlamaPreTrainedModel):
         # embed positions
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
src/transformers/models/mistral/modeling_mistral.py
@@ -855,6 +855,13 @@ class MistralModel(MistralPreTrainedModel):
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         past_key_values_length = 0
 
         if use_cache:
@@ -899,13 +906,6 @@ class MistralModel(MistralPreTrainedModel):
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
src/transformers/models/persimmon/modeling_persimmon.py
@@ -608,6 +608,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
@@ -635,13 +642,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
src/transformers/models/phi/modeling_phi.py
@@ -860,6 +860,13 @@ class PhiModel(PhiPreTrainedModel):
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
@@ -890,13 +897,6 @@ class PhiModel(PhiPreTrainedModel):
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
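For context, the code path patched here is the one hit when fine-tuning with PEFT adapters while gradient checkpointing is enabled, as referenced in the commit title. The snippet below is an illustrative setup only; the model name and LoRA hyperparameters are placeholders.

# Illustrative PEFT + gradient-checkpointing setup (placeholder model name and
# LoRA hyperparameters); this combination exercises the guard moved by this commit.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

# Gradient checkpointing trades compute for memory. With this fix, the forward
# pass now sets use_cache=False before any past-key-value handling in training.
model.gradient_checkpointing_enable()
model.train()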