Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8252e24a
Unverified
Commit
8252e24a
authored
Mar 29, 2023
by
Younes Belkada
Committed by
GitHub
Mar 29, 2023
Browse files
[`Generate`] Add conditional generation for multimodal models (#22424)
* add conditional generation * add comments
parent
33f4cb10
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
0 deletions
+4
-0
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+4
-0
No files found.
src/transformers/generation/utils.py
View file @
8252e24a
...
@@ -1288,6 +1288,10 @@ class GenerationMixin:
...
@@ -1288,6 +1288,10 @@ class GenerationMixin:
model_kwargs
=
model_kwargs
,
model_kwargs
=
model_kwargs
,
device
=
inputs_tensor
.
device
,
device
=
inputs_tensor
.
device
,
)
)
# conditional generation for multi-modal models.
if
"input_ids"
in
model_kwargs
and
model_input_name
==
"pixel_values"
:
input_ids
=
torch
.
cat
([
input_ids
,
model_kwargs
.
pop
(
"input_ids"
)],
dim
=-
1
)
else
:
else
:
input_ids
=
inputs_tensor
if
model_input_name
==
"input_ids"
else
model_kwargs
.
pop
(
"input_ids"
)
input_ids
=
inputs_tensor
if
model_input_name
==
"input_ids"
else
model_kwargs
.
pop
(
"input_ids"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment