chenpangpang / transformers · Commits · 8205b264

Unverified commit 8205b264, authored Jan 11, 2024 by jiqing-feng, committed by GitHub on Jan 11, 2024.

Assistant model may be on a different device (#27995)

* Assistant model may be on a different device
* fix tensor device
parent cbbe3074
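In practice, this change lets assisted generation run with a draft ("assistant") model that lives on a different device than the main model. A minimal usage sketch under that assumption (the checkpoints, devices, and prompt below are placeholders chosen for illustration, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints; any compatible main/draft pair with a shared tokenizer works.
main_ckpt = "meta-llama/Llama-2-7b-hf"
draft_ckpt = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(main_ckpt)
model = AutoModelForCausalLM.from_pretrained(main_ckpt, torch_dtype=torch.float16).to("cuda:0")
# The assistant model deliberately lives on a second device; before this fix the candidate
# tensors it produced could end up on the wrong device during verification.
assistant = AutoModelForCausalLM.from_pretrained(draft_ckpt, torch_dtype=torch.float16).to("cuda:1")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))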
Showing 2 changed files with 14 additions and 2 deletions (+14 -2):

src/transformers/generation/candidate_generator.py  +8 -1
src/transformers/generation/utils.py  +6 -1
src/transformers/generation/candidate_generator.py

@@ -96,6 +96,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
     ):
+        # Make sure all data at the same device as assistant model
+        device = assistant_model.device
+        input_ids = input_ids.to(device)
+        inputs_tensor = inputs_tensor.to(device)
+
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
         self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

@@ -104,7 +109,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
         assistant_kwargs = {}
         for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
             if key not in ("encoder_outputs", "assistant_encoder_outputs"):
-                assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value)
+                assistant_kwargs[key] = (
+                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
+                )
 
         if "assistant_encoder_outputs" in model_kwargs:
             assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
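The loop changed above follows a general pattern for handing model kwargs to a model on another device: tensors are detached and moved, everything else is deep-copied. A standalone sketch of that pattern (the helper name and the toy kwargs are illustrative, not part of the library):

import copy

import torch


def align_kwargs_to_device(model_kwargs, device, skip=("encoder_outputs", "assistant_encoder_outputs")):
    """Detach tensors and move them to `device`; deep-copy non-tensor entries."""
    aligned = {}
    for key, value in model_kwargs.items():
        if key in skip:
            continue
        aligned[key] = value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
    return aligned


# Toy usage: an attention mask created on the CPU is re-homed for a model on the target device.
kwargs = {"attention_mask": torch.ones(1, 5), "use_cache": True}
print(align_kwargs_to_device(kwargs, torch.device("cpu")))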
src/transformers/generation/utils.py

@@ -4585,7 +4585,12 @@ class GenerationMixin:
             cur_len = input_ids.shape[-1]
 
             # 1. Fetch candidate sequences from a `CandidateGenerator`
-            candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
+            candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
+                input_ids.to(candidate_generator.assistant_model.device)
+            )
+            candidate_input_ids = candidate_input_ids.to(self.device)
+            candidate_logits = candidate_logits.to(self.device)
+
             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
             last_assistant_token_is_eos = (
                 ~candidate_input_ids[:, -1]
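The utils.py side completes the round trip: the prompt is moved onto the assistant's device before candidates are requested, and the candidate ids and logits are moved back onto the main model's device before verification. A condensed sketch of that hand-off (the helper function is an illustrative stand-in for the assisted-decoding loop, not the library's code):

import torch


def fetch_candidates_across_devices(main_model, candidate_generator, input_ids):
    """Draft on the assistant's device, then return tensors on the main model's device."""
    # Cross the device boundary on the way in: the draft model only sees tensors on its own device.
    candidate_ids, candidate_logits = candidate_generator.get_candidates(
        input_ids.to(candidate_generator.assistant_model.device)
    )
    # Come back before verification so the main model never receives tensors from another device.
    candidate_ids = candidate_ids.to(main_model.device)
    candidate_logits = candidate_logits.to(main_model.device)
    return candidate_ids, candidate_logits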