Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bc72b4e2
Unverified
Commit
bc72b4e2
authored
Jan 13, 2024
by
Joao Gante
Committed by
GitHub
Jan 13, 2024
Browse files
Generate: fix candidate device placement (#28493)
* fix candidate device
* this line shouldn't have been in
parent
e304f976
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 5 additions and 4 deletions
+5
-4
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+2
-0
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+3
-4
No files found.
src/transformers/generation/candidate_generator.py
View file @
bc72b4e2
...
@@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
vocabulary_size)` containing the logits associated to each candidate.
"""
"""
input_ids = input_ids.to(self.assistant_model.device)
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
...
...
src/transformers/generation/utils.py
View file @
bc72b4e2
...
@@ -4591,11 +4591,10 @@ class GenerationMixin:
...
@@ -4591,11 +4591,10 @@ class GenerationMixin:
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[-1]
# 1. Fetch candidate sequences from a `CandidateGenerator`
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
input_ids.to(candidate_generator.assistant_model.device)
)
candidate_input_ids = candidate_input_ids.to(self.device)
candidate_input_ids = candidate_input_ids.to(self.device)
candidate_logits = candidate_logits.to(self.device)
if candidate_logits is not None:
    candidate_logits = candidate_logits.to(self.device)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos
=
(
last_assistant_token_is_eos
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment