Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bf3dfd11
Unverified
Commit
bf3dfd11
authored
Mar 18, 2024
by
Joao Gante
Committed by
GitHub
Mar 18, 2024
Browse files
CI / generate: batch size computation compatible with all models (#29671)
parent
00c1d87a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
20 deletions
+12
-20
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+12
-20
No files found.
src/transformers/generation/utils.py
View file @
bf3dfd11
...
@@ -1949,11 +1949,9 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
...
@@ -2398,12 +2396,10 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
         this_peer_finished = False
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
...
@@ -2686,12 +2682,10 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
         this_peer_finished = False
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
...
@@ -4461,11 +4455,9 @@ class GenerationMixin:
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        batch_size, cur_len = (
+            model_kwargs["attention_mask"].shape
+            if model_kwargs.get("attention_mask", None) is not None
+            else input_ids.shape
+        )
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment