chenpangpang / transformers · Commits

Commit bdb9106f (unverified)
Authored May 24, 2024 by Pablo Montalvo, committed via GitHub on May 24, 2024
Parent: deba7655

Paligemma- fix devices and dtype assignments (#31008)

* fix devices and dtype assignments
* [run-slow]paligemma
Showing 1 changed file with 9 additions and 6 deletions.

src/transformers/models/paligemma/modeling_paligemma.py (+9, -6)
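All three hunks below address the same failure mode: when the checkpoint is dispatched across devices (for example with device_map="auto"), tensors derived from input_ids or produced by the processor can sit on a different device than inputs_embeds, and PyTorch refuses to combine them. A minimal sketch of the symptom and the fix, using illustrative tensors rather than the model's real ones:

import torch

# Illustrative only: a boolean mask built from input_ids (often still on CPU) combined
# with embeddings that live on an accelerator. Without an explicit .to(), torch.where
# raises "Expected all tensors to be on the same device".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
inputs_embeds = torch.randn(1, 4, 8, device=device)
text_mask = torch.tensor([[True, True, False, True]])  # built on CPU from input_ids

text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, inputs_embeds.shape[-1])
text_mask_expanded = text_mask_expanded.to(inputs_embeds.device)  # the fix: move the mask first
merged = torch.where(text_mask_expanded, inputs_embeds, torch.zeros_like(inputs_embeds))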
@@ -301,14 +301,15 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
         pad_mask = input_ids == self.pad_token_id
         # expand masks to match embedding dimension
-        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
-        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
+        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
+        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
         # insert padding and text token embeddings
         final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         # insert image embeddings - the image mask is always less or equal to the sentence in length
         final_embedding = final_embedding.masked_scatter(
-            image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
+            image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
+            scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

         if attention_mask is not None:
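The masked_scatter change matters because the mask must live on the destination tensor's device and the source must match the destination's dtype; with the vision tower and the language model potentially placed on different devices (or running in different precisions), image_mask is moved and scaled_image_features is moved and cast explicitly. A minimal sketch with made-up shapes:

import torch

# Made-up shapes: masked_scatter copies source values into destination positions where
# the mask is True. It requires the mask on the destination's device and the source in
# the destination's dtype, hence the explicit .to(device=..., dtype=...) in the hunk above.
final_embedding = torch.zeros(1, 6, 4, dtype=torch.bfloat16)
image_mask = torch.tensor([[True, True, True, False, False, False]])
scaled_image_features = torch.randn(1, 3, 4, dtype=torch.float16)  # e.g. vision tower in fp16

final_embedding = final_embedding.masked_scatter(
    image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
    scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
)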
@@ -329,10 +330,12 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             # unmask the prefill
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                token_type_ids[:, None, None, :] == 0, 0
+                token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
             )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
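In this hunk the 4D causal mask is allocated on the model's device, while attention_mask and token_type_ids come from the processor and may still be on CPU or another GPU; both are moved before the addition and the masked_fill comparison. A runnable sketch with toy values:

import torch

# Toy values: causal_mask stands in for the model-device mask; attention_mask and
# token_type_ids stand in for processor outputs. Moving them with .to(causal_mask.device)
# keeps the addition and the comparison on a single device.
causal_mask = torch.zeros(1, 1, 4, 4)
attention_mask = torch.ones(1, 4)
token_type_ids = torch.tensor([[0, 0, 1, 1]])  # 0 = prefill (image + prompt), 1 = suffix

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
# unmask the prefill so those positions can attend to each other
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)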
@@ -484,7 +487,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
...
@@ -484,7 +487,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we use the input attention mask to shift the logits and labels, because it is 2D.
shift_attention_mask
=
input_attention_mask
[...,
1
:]
shift_attention_mask
=
input_attention_mask
[...,
1
:]
shift_logits
=
shift_logits
[
shift_attention_mask
.
to
(
logits
.
device
)
!=
0
].
contiguous
()
shift_logits
=
shift_logits
[
shift_attention_mask
.
to
(
logits
.
device
)
!=
0
].
contiguous
()
shift_labels
=
shift_labels
[
shift_attention_mask
.
to
(
logit
s
.
device
)
!=
0
].
contiguous
()
shift_labels
=
shift_labels
[
shift_attention_mask
.
to
(
shift_label
s
.
device
)
!=
0
].
contiguous
()
else
:
else
:
shift_logits
=
shift_logits
.
contiguous
()
shift_logits
=
shift_logits
.
contiguous
()
shift_labels
=
shift_labels
.
contiguous
()
shift_labels
=
shift_labels
.
contiguous
()
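The last hunk fixes boolean-mask indexing in the loss computation: the mask has to be on the same device as the tensor it indexes, and in a sharded model shift_labels is not guaranteed to share a device with logits, so the mask is now moved to shift_labels.device for that line. A toy sketch of the pattern:

import torch

# Toy tensors: boolean-mask indexing requires the mask to be on the same device as the
# tensor being indexed, so each tensor gets the mask moved to its own device rather than
# unconditionally to logits.device.
logits = torch.randn(1, 5, 10)
labels = torch.tensor([[1, 2, 3, 4, 5]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])

shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
shift_attention_mask = attention_mask[..., 1:]
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()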