chenpangpang / transformers · Commits · 360719a6
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "28f7ca1f807f0857c24f18c0b28b6b8ebee18c0a"
Unverified commit 360719a6, authored Jul 06, 2022 by Joao Gante, committed by GitHub on Jul 06, 2022

TF: GPT-J compatible with XLA generation (#17986)

Parent: bf37e5c7
Showing 2 changed files with 85 additions and 142 deletions:

  src/transformers/models/gptj/modeling_tf_gptj.py   +41  -49
  tests/models/gptj/test_modeling_tf_gptj.py         +44  -93
src/transformers/models/gptj/modeling_tf_gptj.py
...
...
@@ -60,14 +60,12 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-def fixed_pos_embedding(x: tf.Tensor, seq_dim: int = 1, seq_len: Optional[int] = None) -> Tuple[tf.Tensor, tf.Tensor]:
-    dim = shape_list(x)[-1]
-    if seq_len is None:
-        seq_len = shape_list(x)[seq_dim]
+def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
     inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
-    seq_len_range = tf.cast(tf.range(seq_len), tf.float32)
-    sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", seq_len_range, inv_freq), tf.float32)
-    return tf.cast(tf.sin(sinusoid_inp), dtype=x.dtype), tf.cast(tf.cos(sinusoid_inp), dtype=x.dtype)
+    sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32)
+    sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)
+    out = tf.concat((sin, cos), axis=1)
+    return out


 def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
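
The hunk above swaps the per-call fixed_pos_embedding for a sinusoidal table that is built once and indexed later by position id. A minimal, self-contained sketch of the new scheme, reusing the helper exactly as added above (the table size and position_ids values are illustrative, not from the commit):

    import tensorflow as tf

    # Sketch only: build the sin/cos table once, then index it by position ids.
    def create_sinusoidal_positions(num_pos: int, dim: int) -> tf.Tensor:
        inv_freq = tf.cast(1.0 / (10000 ** (tf.range(0, dim, 2) / dim)), tf.float32)
        sinusoid_inp = tf.cast(tf.einsum("i , j -> i j", tf.range(num_pos, dtype=tf.float32), inv_freq), tf.float32)
        sin, cos = tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)
        return tf.concat((sin, cos), axis=1)                    # (num_pos, dim): [sin | cos]

    table = create_sinusoidal_positions(num_pos=8, dim=4)       # illustrative sizes
    position_ids = tf.constant([[0, 1, 2, 3]])                  # (batch, seq_len)
    sincos = tf.gather(table, position_ids, axis=0)             # (1, 4, 4): one table row per token
    sin_pos, cos_pos = tf.split(sincos, 2, axis=-1)             # two (1, 4, 2) halves
    print(sin_pos.shape, cos_pos.shape)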
...
...
@@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
     return rotate_half_tensor


-def apply_rotary_pos_emb(x: tf.Tensor, sincos: tf.Tensor, offset: int = 0) -> tf.Tensor:
+def apply_rotary_pos_emb(tensor: tf.Tensor, sincos: tf.Tensor) -> tf.Tensor:
     sin_pos, cos_pos = sincos
-    sin_pos = tf.repeat(sin_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
-    cos_pos = tf.repeat(cos_pos[None, offset : shape_list(x)[1] + offset, None, :], 2, 3)
-    return (x * cos_pos) + (rotate_every_two(x) * sin_pos)
+    sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)
+    cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
+    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)


 class TFGPTJAttention(tf.keras.layers.Layer):
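
With the table gathered per position, apply_rotary_pos_emb no longer needs an offset argument: the absolute position is already baked into the sin/cos it receives. A self-contained sketch of the new helper on dummy tensors; rotate_every_two is collapsed in this diff, so the interleaved rotation below is an illustrative stand-in rather than a copy of the file's implementation:

    import tensorflow as tf

    def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...); stand-in for the collapsed helper
        rotated = tf.stack((-x[..., 1::2], x[..., ::2]), axis=-1)
        return tf.reshape(rotated, tf.shape(x))

    def apply_rotary_pos_emb(tensor: tf.Tensor, sincos) -> tf.Tensor:
        sin_pos, cos_pos = sincos
        sin_pos = tf.repeat(sin_pos[:, :, None, :], 2, 3)        # (batch, seq, 1, rotary_dim)
        cos_pos = tf.repeat(cos_pos[:, :, None, :], 2, 3)
        return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)

    batch, seq, heads, rotary_dim = 1, 3, 2, 4                   # illustrative sizes
    key = tf.random.normal((batch, seq, heads, rotary_dim))
    sin = tf.random.normal((batch, seq, rotary_dim // 2))        # stand-ins for the gathered table halves
    cos = tf.random.normal((batch, seq, rotary_dim // 2))
    print(apply_rotary_pos_emb(key, (sin, cos)).shape)           # (1, 3, 2, 4)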
...
...
@@ -132,6 +130,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
             tf.cast(tf.experimental.numpy.tril(tf.ones((self.max_positions, self.max_positions))), tf.int8),
             (1, 1, self.max_positions, self.max_positions),
         )
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(self.max_positions, pos_embd_dim)

     def get_causal_mask(self, key_length, query_length) -> tf.Tensor:
         return tf.cast(self.lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)
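
get_causal_mask slices a boolean window out of the mask that was precomputed in __init__, so the masking logic stays shape-static under XLA. A small sketch with illustrative sizes, one new query token against a cache of three keys:

    import tensorflow as tf

    max_positions = 5
    lower_triangle_mask = tf.reshape(
        tf.cast(tf.experimental.numpy.tril(tf.ones((max_positions, max_positions))), tf.int8),
        (1, 1, max_positions, max_positions),
    )

    def get_causal_mask(key_length, query_length) -> tf.Tensor:
        return tf.cast(lower_triangle_mask[:, :, key_length - query_length : key_length, :key_length], tf.bool)

    # query_length=1 (one new token), key_length=4 (3 cached tokens + the new one)
    print(get_causal_mask(key_length=4, query_length=1).numpy())  # [[[[ True  True  True  True]]]]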
...
...
@@ -207,8 +207,9 @@ class TFGPTJAttention(tf.keras.layers.Layer):
     def call(
         self,
         hidden_states: tf.Tensor,
-        attention_mask: Optional[tf.Tensor] = None,
         layer_past: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
         head_mask: Optional[tf.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
...
...
@@ -221,13 +222,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
         key = self._split_heads(key, True)
         value = self._split_heads(value, False)

-        seq_len = shape_list(key)[1]
-        offset = 0
-
-        if layer_past is not None:
-            offset = shape_list(layer_past[0])[-2]
-            seq_len += offset
-
+        sincos = tf.gather(self.embed_positions, position_ids, axis=0)
+        sincos = tf.split(sincos, 2, axis=-1)
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
             k_pass = key[:, :, :, self.rotary_dim :]
...
...
@@ -235,16 +231,14 @@ class TFGPTJAttention(tf.keras.layers.Layer):
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]

-            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
-            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
-            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+            k_rot = apply_rotary_pos_emb(k_rot, sincos)
+            q_rot = apply_rotary_pos_emb(q_rot, sincos)

             key = tf.concat((k_rot, k_pass), axis=-1)
             query = tf.concat((q_rot, q_pass), axis=-1)
         else:
-            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
-            key = apply_rotary_pos_emb(key, sincos, offset=offset)
-            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+            key = apply_rotary_pos_emb(key, sincos)
+            query = apply_rotary_pos_emb(query, sincos)

         key = tf.transpose(key, (0, 2, 1, 3))
         query = tf.transpose(query, (0, 2, 1, 3))
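
The two hunks above remove the offset bookkeeping entirely: the caller's position_ids select the right rows of the precomputed table whether or not a cache is present. A self-contained sketch of why the gather is equivalent to the old offset arithmetic (a toy table stands in for the sinusoidal one; sizes are illustrative):

    import tensorflow as tf

    table = tf.reshape(tf.range(16 * 8, dtype=tf.float32), (16, 8))  # (max_positions, pos_embd_dim), toy values

    past_length = 5                                  # tokens already in the cache
    position_ids = tf.constant([[past_length]])      # one new token per sequence -> its absolute position

    gathered = tf.gather(table, position_ids, axis=0)         # (1, 1, 8), what the new code uses
    offset_slice = table[past_length : past_length + 1]       # the row the old offset-based code selected
    print(bool(tf.reduce_all(gathered[0] == offset_slice)))   # True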
...
...
@@ -310,6 +304,7 @@ class TFGPTJBlock(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         layer_past: Optional[tf.Tensor] = None,
         attention_mask: Optional[tf.Tensor] = None,
+        position_ids: Optional[tf.Tensor] = None,
         head_mask: Optional[tf.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
...
...
@@ -317,9 +312,10 @@ class TFGPTJBlock(tf.keras.layers.Layer):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            hidden_states,
+            hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
...
...
@@ -466,12 +462,13 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)

             outputs = block(
-                hidden_states,
-                layer_past,
-                attention_mask,
-                head_mask[i],
-                use_cache,
-                output_attentions,
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
                 training=training,
             )
...
...
@@ -722,8 +719,6 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
         self.lm_head = tf.keras.layers.Dense(
             config.vocab_size, kernel_initializer=get_initializer(config.initializer_range), name="lm_head"
         )
-        # TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
-        self.supports_xla_generation = False

     def get_output_embeddings(self):
         return self.lm_head
...
...
@@ -731,25 +726,21 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
-        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
-        # tests will need to be fixed after the change
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+        token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)
+            if token_type_ids is not None:
+                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
+
+        position_ids = kwargs.get("position_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)

-        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
-        # for a future PR to not change too many things for now.
-        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
-        position_ids = None
-        attention_mask = None
-        if use_xla:
-            attention_mask = kwargs.get("attention_mask", None)
-            if past is not None and attention_mask is not None:
-                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
-            elif attention_mask is not None:
-                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+        if attention_mask is not None and position_ids is None:
+            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            if past:
+                position_ids = tf.expand_dims(position_ids[:, -1], -1)

         return {
             "input_ids": inputs,
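
The replacement prepare_inputs_for_generation derives position_ids from the attention mask with an exclusive cumulative sum, which is what lets left-padded batches generate correctly with and without XLA. A runnable sketch of just that arithmetic (the example mask is illustrative):

    import tensorflow as tf

    # Left-padded batch: 0 = padding, 1 = real token.
    attention_mask = tf.constant([[0, 0, 1, 1, 1],
                                  [1, 1, 1, 1, 1]])
    position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
    print(position_ids.numpy())
    # [[0 0 0 1 2]
    #  [0 1 2 3 4]]   -> real tokens get 0-based positions; pad slots are masked out of attention anyway

    # With a cache (`past`), only the newest token is fed, so only the last column is kept:
    print(tf.expand_dims(position_ids[:, -1], -1).numpy())
    # [[2]
    #  [4]]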
...
...
@@ -757,6 +748,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
             "position_ids": position_ids,
             "past": past,
             "use_cache": use_cache,
+            "token_type_ids": token_type_ids,
         }

     @unpack_inputs
...
...
tests/models/gptj/test_modeling_tf_gptj.py
...
...
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import datetime
 import unittest

 from transformers import AutoTokenizer, GPTJConfig, is_tf_available
...
...
@@ -48,6 +47,7 @@ class TFGPTJModelTester:
         self.use_mc_token_ids = True
         self.vocab_size = 99
         self.hidden_size = 32
+        self.rotary_dim = 4
         self.num_hidden_layers = 5
         self.num_attention_heads = 4
         self.intermediate_size = 37
...
...
@@ -103,6 +103,7 @@ class TFGPTJModelTester:
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,
+            rotary_dim=self.rotary_dim,
             return_dict=True,
         )
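
Adding rotary_dim=4 to the tester (and forwarding it into the config above) makes the tiny test model take the partial-rotary branch of TFGPTJAttention (the k_rot/k_pass split). A hedged sketch of such a configuration; the values mirror the tester above and are otherwise illustrative:

    from transformers import GPTJConfig

    config = GPTJConfig(
        vocab_size=99,
        n_embd=32,       # hidden size
        n_layer=5,
        n_head=4,        # head dim = 32 / 4 = 8
        rotary_dim=4,    # smaller than the head dim -> only part of each head is rotated
    )
    print(config.rotary_dim, config.n_embd // config.n_head)  # 4 8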
...
...
@@ -359,10 +360,10 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestCase):
 @require_tf
+@tooslow
+# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM.
 class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
-    @tooslow
     def test_lm_generate_gptj(self):
-        # Marked as @tooslow due to GPU OOM
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
         input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog
         # fmt: off
...
...
@@ -372,74 +373,20 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
         output_ids = model.generate(input_ids, do_sample=False)
         self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)

-    @tooslow
     def test_gptj_sample(self):
-        # Marked as @tooslow due to GPU OOM (issue #13676)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)

-        tf.random.set_seed(0)
-        tokenized = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True)
-        input_ids, token_type_ids = tokenized.input_ids, tokenized.token_type_ids
-        output_ids = model.generate(input_ids, do_sample=True)
+        tokenized = tokenizer("Today is a nice day and", return_tensors="tf")
+        # forces the generation to happen on CPU, to avoid GPU-related quirks
+        with tf.device(":/CPU:0"):
+            output_ids = model.generate(**tokenized, do_sample=True, seed=[42, 0])
         output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

-        output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
-        output_seq_tt = model.generate(input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5)
-        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
-        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
-
-        EXPECTED_OUTPUT_STR = "Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"
-
+        EXPECTED_OUTPUT_STR = "Today is a nice day and I’m going to go for a walk. I’"
         self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
-        self.assertTrue(all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]))  # token_type_ids should change output
-
-    @slow
-    @unittest.skip(reason="TF generate currently has no time-based stopping criteria")
-    def test_gptj_sample_max_time(self):
-        tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
-        model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)
-
-        input_ids = tokenizer("Today is a nice day and", return_tensors="tf", return_token_type_ids=True).input_ids
-
-        MAX_TIME = 0.5
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
-        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-        start = datetime.datetime.now()
-        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
-        duration = datetime.datetime.now() - start
-        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
-
-    @tooslow
-    def test_batch_generation(self):
-        # Marked as @tooslow due to GPU OOM
+    def _get_beam_search_test_objects(self):
         model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", from_pt=True)
         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
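
test_gptj_sample now pins randomness by passing seed=[42, 0] to generate() instead of calling tf.random.set_seed(0): TF's stateless random ops take a pair of integers as their seed, which is why the argument has that shape. A small sketch of the seed-pair behaviour with a stateless op (shown only to illustrate the convention; this diff does not show what generate uses internally):

    import tensorflow as tf

    logits = tf.math.log([[0.1, 0.2, 0.3, 0.4]])
    draw_a = tf.random.stateless_categorical(logits, num_samples=5, seed=[42, 0])
    draw_b = tf.random.stateless_categorical(logits, num_samples=5, seed=[42, 0])
    print(draw_a.numpy().tolist() == draw_b.numpy().tolist())  # True: same seed pair, same samples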
...
...
@@ -454,42 +401,46 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
             "Hello, my dog is a little",
             "Today, I",
         ]
+        expected_output_sentences = [
+            "Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia",
+            "Today, I’m going to be talking about a topic that’",
+        ]
+        return model, tokenizer, sentences, expected_output_sentences

-        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
-        input_ids = inputs["input_ids"]
-        token_type_ids = tf.concat(
-            [
-                tf.zeros((input_ids.shape[0], input_ids.shape[1] - 1), dtype=tf.int64),
-                500 * tf.ones((input_ids.shape[0], 1), dtype=tf.int64),
-            ],
-            axis=-1,
-        )
-
-        outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
-        outputs_tt = model.generate(
-            input_ids=input_ids,
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=token_type_ids,
-        )
+    def test_batch_beam_search(self):
+        # Confirms that we get the expected results with left-padded beam search
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()

-        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
-        output_non_padded = model.generate(input_ids=inputs_non_padded)
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        outputs = model.generate(**inputs, do_sample=False, num_beams=2)
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, batch_out_sentence)

+    def test_batch_left_padding(self):
+        # Confirms that left-padding is working properly
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf")
+        output_non_padded = model.generate(**inputs_non_padded, do_sample=False, num_beams=2)
         num_paddings = (
-            shape_list(inputs_non_padded)[-1] - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
+            shape_list(inputs_non_padded["input_ids"])[-1]
+            - tf.reduce_sum(tf.cast(inputs["attention_mask"][-1], tf.int64)).numpy()
         )
-        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
-        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
-
-        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
+        inputs_padded = tokenizer(sentences[1], return_tensors="tf")
+        output_padded = model.generate(**inputs_padded, do_sample=False, num_beams=2, max_length=model.config.max_length - num_paddings)
         non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
         padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, [non_padded_sentence, padded_sentence])

-        expected_output_sentence = [
-            "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur",
-            "Today, I’m going to share with you a few of my favorite",
-        ]
-        self.assertListEqual(expected_output_sentence, batch_out_sentence)
-        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
-        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+    def test_xla_beam_search(self):
+        # Confirms that XLA is working properly
+        model, tokenizer, sentences, expected_output_sentences = self._get_beam_search_test_objects()
+
+        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+        xla_generate = tf.function(model.generate, jit_compile=True)
+        outputs_xla = xla_generate(**inputs, do_sample=False, num_beams=2)
+        xla_sentence = tokenizer.batch_decode(outputs_xla, skip_special_tokens=True)
+        self.assertListEqual(expected_output_sentences, xla_sentence)
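
test_xla_beam_search compiles generate with tf.function(jit_compile=True) and feeds a left-padded batch, which is exactly the pattern this commit unlocks for GPT-J. A hedged sketch of the same pattern using the small random checkpoint from the removed max-time test so it can run without the 6B weights; the pad-token, padding-side, and max_length settings are assumptions for the sketch, not taken from this diff:

    import tensorflow as tf
    from transformers import AutoTokenizer, TFGPTJForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
    tokenizer.pad_token = tokenizer.eos_token      # assumption: GPT-J tokenizers ship without a pad token
    tokenizer.padding_side = "left"                # decoder-only models are padded on the left
    model = TFGPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random", from_pt=True)  # needs torch installed

    inputs = tokenizer(["Hello, my dog is a little", "Today, I"], return_tensors="tf", padding=True)
    xla_generate = tf.function(model.generate, jit_compile=True)   # XLA-compiled generation
    outputs = xla_generate(**inputs, do_sample=False, num_beams=2, max_length=16, pad_token_id=tokenizer.pad_token_id)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # random weights -> meaningless text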