chenpangpang / transformers · Commits · 7628b3a0

Unverified commit 7628b3a0, authored Feb 28, 2024 by Joao Gante, committed via GitHub on Feb 28, 2024.

Idefics: generate fix (#29320)

parent 2ce56d35
Showing 1 changed file with 21 additions and 33 deletions:

src/transformers/models/idefics/modeling_idefics.py (+21, -33)
src/transformers/models/idefics/modeling_idefics.py (view file @ 7628b3a0)

@@ -19,7 +19,7 @@
 # limitations under the License.
 """ PyTorch Idefics model."""
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -187,35 +187,6 @@ def expand_inputs_for_generation(
     return input_ids, model_kwargs
 
 
-def update_model_kwargs_for_generation(outputs, model_kwargs):
-    # must have this key set to at least None
-    if "past_key_values" in outputs:
-        model_kwargs["past_key_values"] = outputs.past_key_values
-    else:
-        model_kwargs["past_key_values"] = None
-
-    # update token_type_ids with last value
-    if "token_type_ids" in model_kwargs:
-        token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
-
-    # update attention masks
-    if "attention_mask" in model_kwargs:
-        attention_mask = model_kwargs["attention_mask"]
-        model_kwargs["attention_mask"] = torch.cat(
-            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-        )
-    if "image_attention_mask" in model_kwargs:
-        image_attention_mask = model_kwargs["image_attention_mask"]
-        last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-        model_kwargs["image_attention_mask"] = last_mask
-
-    # Get the precomputed image_hidden_states
-    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
-
-    return model_kwargs
-
-
 def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
     token_type_ids = kwargs.get("token_type_ids", None)
     # only last token for inputs_ids if past is defined in kwargs
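The helper removed in this hunk implemented Idefics' per-step bookkeeping during generation. As a quick illustration of what that bookkeeping does, here is a minimal, self-contained sketch on dummy tensors (the shapes are illustrative assumptions, no model involved): the text attention mask grows by one column per generated token, while only the last slice of the image attention mask is carried forward once past_key_values are cached.

import torch

# Illustrative shapes: batch of 2, 5 text positions, 3 image positions.
attention_mask = torch.ones(2, 5, dtype=torch.long)
image_attention_mask = torch.ones(2, 5, 3, dtype=torch.long)

# The text attention mask gains one column for the newly generated token.
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)

# Only the most recent step of the image attention mask is kept, because with
# cached past_key_values the model is fed a single new token per step.
image_attention_mask = image_attention_mask[:, -1, :].unsqueeze(1)

print(attention_mask.shape)        # torch.Size([2, 6])
print(image_attention_mask.shape)  # torch.Size([2, 1, 3])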
@@ -1580,9 +1551,26 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
     ):
         return expand_inputs_for_generation(*args, **model_kwargs)
 
-    @staticmethod
-    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
-        return update_model_kwargs_for_generation(outputs, model_kwargs)
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+        model_inputs: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
+        )
+
+        if "image_attention_mask" in model_kwargs:
+            image_attention_mask = model_kwargs["image_attention_mask"]
+            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
+            model_kwargs["image_attention_mask"] = last_mask
+
+        # Get the precomputed image_hidden_states
+        model_kwargs["image_hidden_states"] = outputs.image_hidden_states
+        return model_kwargs
 
     @staticmethod
     def _reorder_cache(past, beam_idx):
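The override added in this hunk is the hook that GenerationMixin.generate() calls after every decoding step; the fix delegates the generic updates (past_key_values, attention_mask, token_type_ids) to super() and then patches only the Idefics-specific keys. Below is a hedged usage sketch of the code path this touches, assuming the public HuggingFaceM4/idefics-9b checkpoint and the processor API as of this commit; the prompt content and generation settings are illustrative and not part of the change.

import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b"  # assumed checkpoint; other Idefics variants should behave the same
processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")

# Idefics prompts interleave text and images (here referenced by URL).
prompts = [
    [
        "User: What is in this image?",
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "<end_of_utterance>",
        "\nAssistant:",
    ]
]
inputs = processor(prompts, return_tensors="pt").to(model.device)

# generate() invokes _update_model_kwargs_for_generation between steps,
# which is exactly the method this commit rewrites.
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])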