Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
60b69f7d
Unverified
Commit
60b69f7d
authored
Jun 12, 2023
by
Joao Gante
Committed by
GitHub
Jun 12, 2023
Browse files
Generate: detect special architectures when loaded from PEFT (#24198)
parent
97527898
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
3 deletions
+17
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+17
-3
No files found.
src/transformers/generation/utils.py
View file @
60b69f7d
...
...
@@ -4232,6 +4232,15 @@ class GenerationMixin:
# other auxiliary variables
max_len
=
stopping_criteria
[
0
].
max_length
assistant_kv_indexing
=
(
1
if
"bloom"
in
assistant_model
.
__class__
.
__name__
.
lower
()
or
(
assistant_model
.
config
.
architectures
is
not
None
and
"bloom"
in
assistant_model
.
config
.
architectures
[
0
].
lower
()
)
else
0
)
this_peer_finished
=
False
# used by synced_gpus only
while
True
:
...
...
@@ -4247,7 +4256,6 @@ class GenerationMixin:
# Assistant: main logic start
cur_len
=
input_ids
.
shape
[
-
1
]
assistant_kv_indexing
=
0
if
"bloom"
not
in
assistant_model
.
__class__
.
__name__
.
lower
()
else
1
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
...
...
@@ -4512,7 +4520,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values
=
tuple
(
new_past
)
elif
"bloom"
in
model
.
__class__
.
__name__
.
lower
():
# bloom is special
# bloom is special
elif
"bloom"
in
model
.
__class__
.
__name__
.
lower
()
or
(
model
.
config
.
architectures
is
not
None
and
"bloom"
in
model
.
config
.
architectures
[
0
].
lower
()
):
for
idx
in
range
(
len
(
past_key_values
)):
new_past
.
append
(
(
...
...
@@ -4521,7 +4532,10 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
)
)
past_key_values
=
tuple
(
new_past
)
elif
"gptbigcode"
in
model
.
__class__
.
__name__
.
lower
():
# gptbigcode is too
# gptbigcode is too
elif
"gptbigcode"
in
model
.
__class__
.
__name__
.
lower
()
or
(
model
.
config
.
architectures
is
not
None
and
"gptbigcode"
in
model
.
config
.
architectures
[
0
].
lower
()
):
if
model
.
config
.
multi_query
:
for
idx
in
range
(
len
(
past_key_values
)):
past_key_values
[
idx
]
=
past_key_values
[
idx
][:,
:
maximum_length
,
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment