Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5ed9bd18
Unverified
Commit
5ed9bd18
authored
Oct 20, 2022
by
Joao Gante
Committed by
GitHub
Oct 20, 2022
Browse files
TF: sample generation compatible with XLA and dynamic batch sizes (#19773)
parent
c186e816
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
src/transformers/generation_tf_utils.py
src/transformers/generation_tf_utils.py
+3
-2
No files found.
src/transformers/generation_tf_utils.py
View file @
5ed9bd18
...
@@ -360,6 +360,7 @@ class TFGenerationMixin:
...
@@ -360,6 +360,7 @@ class TFGenerationMixin:
@
property
@
property
def
seed_generator
(
self
):
def
seed_generator
(
self
):
warnings
.
warn
(
"`seed_generator` is deprecated and will be removed in a future version."
,
UserWarning
)
if
self
.
_seed_generator
is
None
:
if
self
.
_seed_generator
is
None
:
self
.
_seed_generator
=
tf
.
random
.
Generator
.
from_non_deterministic_state
()
self
.
_seed_generator
=
tf
.
random
.
Generator
.
from_non_deterministic_state
()
return
self
.
_seed_generator
return
self
.
_seed_generator
...
@@ -1920,7 +1921,7 @@ class TFGenerationMixin:
...
@@ -1920,7 +1921,7 @@ class TFGenerationMixin:
**
model_kwargs
,
**
model_kwargs
,
)
->
Tuple
[
tf
.
Tensor
,
Dict
[
str
,
Any
]]:
)
->
Tuple
[
tf
.
Tensor
,
Dict
[
str
,
Any
]]:
expanded_return_idx
=
tf
.
reshape
(
expanded_return_idx
=
tf
.
reshape
(
tf
.
tile
(
tf
.
reshape
(
tf
.
range
(
input_ids
.
shape
[
0
]),
(
-
1
,
1
)),
(
1
,
expand_size
)),
(
-
1
,)
tf
.
tile
(
tf
.
reshape
(
tf
.
range
(
tf
.
shape
(
input_ids
)
[
0
]),
(
-
1
,
1
)),
(
1
,
expand_size
)),
(
-
1
,)
)
)
input_ids
=
tf
.
gather
(
input_ids
,
expanded_return_idx
,
axis
=
0
)
input_ids
=
tf
.
gather
(
input_ids
,
expanded_return_idx
,
axis
=
0
)
...
@@ -2624,7 +2625,7 @@ class TFGenerationMixin:
...
@@ -2624,7 +2625,7 @@ class TFGenerationMixin:
if
seed
is
not
None
:
if
seed
is
not
None
:
sample_seed
=
seed
sample_seed
=
seed
else
:
else
:
sample_seed
=
tf
.
cast
(
self
.
seed_generator
.
make_seeds
(
count
=
1
)[:,
0
]
,
dtype
=
tf
.
int32
)
sample_seed
=
tf
.
experimental
.
numpy
.
random
.
randint
(
tf
.
int32
.
min
,
tf
.
int32
.
max
,
(
2
,)
,
dtype
=
tf
.
int32
)
next_tokens
=
tf
.
squeeze
(
next_tokens
=
tf
.
squeeze
(
tf
.
random
.
stateless_categorical
(
tf
.
random
.
stateless_categorical
(
logits
=
next_tokens_scores
,
num_samples
=
1
,
seed
=
sample_seed
,
dtype
=
tf
.
int32
logits
=
next_tokens_scores
,
num_samples
=
1
,
seed
=
sample_seed
,
dtype
=
tf
.
int32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment