Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
360719a6
Unverified
Commit
360719a6
authored
Jul 06, 2022
by
Joao Gante
Committed by
GitHub
Jul 06, 2022
Browse files
TF: GPT-J compatible with XLA generation (#17986)
parent
bf37e5c7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
85 additions
and
142 deletions
+85
-142
src/transformers/models/gptj/modeling_tf_gptj.py
src/transformers/models/gptj/modeling_tf_gptj.py
+41
-49
tests/models/gptj/test_modeling_tf_gptj.py
tests/models/gptj/test_modeling_tf_gptj.py
+44
-93
No files found.
src/transformers/models/gptj/modeling_tf_gptj.py
View file @
360719a6
...
@@ -60,14 +60,12 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
...
@@ -60,14 +60,12 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
]
def
fixed_pos_embedding
(
x
:
tf
.
Tensor
,
seq_dim
:
int
=
1
,
seq_len
:
Optional
[
int
]
=
None
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
def
create_sinusoidal_positions
(
num_pos
:
int
,
dim
:
int
)
->
tf
.
Tensor
:
dim
=
shape_list
(
x
)[
-
1
]
if
seq_len
is
None
:
seq_len
=
shape_list
(
x
)[
seq_dim
]
inv_freq
=
tf
.
cast
(
1.0
/
(
10000
**
(
tf
.
range
(
0
,
dim
,
2
)
/
dim
)),
tf
.
float32
)
inv_freq
=
tf
.
cast
(
1.0
/
(
10000
**
(
tf
.
range
(
0
,
dim
,
2
)
/
dim
)),
tf
.
float32
)
seq_len_range
=
tf
.
cast
(
tf
.
range
(
seq_len
),
tf
.
float32
)
sinusoid_inp
=
tf
.
cast
(
tf
.
einsum
(
"i , j -> i j"
,
tf
.
range
(
num_pos
,
dtype
=
tf
.
float32
),
inv_freq
),
tf
.
float32
)
sinusoid_inp
=
tf
.
cast
(
tf
.
einsum
(
"i , j -> i j"
,
seq_len_range
,
inv_freq
),
tf
.
float32
)
sin
,
cos
=
tf
.
sin
(
sinusoid_inp
),
tf
.
cos
(
sinusoid_inp
)
return
tf
.
cast
(
tf
.
sin
(
sinusoid_inp
),
dtype
=
x
.
dtype
),
tf
.
cast
(
tf
.
cos
(
sinusoid_inp
),
dtype
=
x
.
dtype
)
out
=
tf
.
concat
((
sin
,
cos
),
axis
=
1
)
return
out
def
rotate_every_two
(
x
:
tf
.
Tensor
)
->
tf
.
Tensor
:
def
rotate_every_two
(
x
:
tf
.
Tensor
)
->
tf
.
Tensor
:
...
@@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
...
@@ -77,11 +75,11 @@ def rotate_every_two(x: tf.Tensor) -> tf.Tensor:
return
rotate_half_tensor
return
rotate_half_tensor
def
apply_rotary_pos_emb
(
x
:
tf
.
Tensor
,
sincos
:
tf
.
Tensor
,
offset
:
int
=
0
)
->
tf
.
Tensor
:
def
apply_rotary_pos_emb
(
tensor
:
tf
.
Tensor
,
sincos
:
tf
.
Tensor
)
->
tf
.
Tensor
:
sin_pos
,
cos_pos
=
sincos
sin_pos
,
cos_pos
=
sincos
sin_pos
=
tf
.
repeat
(
sin_pos
[
None
,
offset
:
shape_list
(
x
)[
1
]
+
offset
,
None
,
:],
2
,
3
)
sin_pos
=
tf
.
repeat
(
sin_pos
[
:,
:
,
None
,
:],
2
,
3
)
cos_pos
=
tf
.
repeat
(
cos_pos
[
None
,
offset
:
shape_list
(
x
)[
1
]
+
offset
,
None
,
:],
2
,
3
)
cos_pos
=
tf
.
repeat
(
cos_pos
[
:,
:
,
None
,
:],
2
,
3
)
return
(
x
*
cos_pos
)
+
(
rotate_every_two
(
x
)
*
sin_pos
)
return
(
tensor
*
cos_pos
)
+
(
rotate_every_two
(
tensor
)
*
sin_pos
)
class
TFGPTJAttention
(
tf
.
keras
.
layers
.
Layer
):
class
TFGPTJAttention
(
tf
.
keras
.
layers
.
Layer
):
...
@@ -132,6 +130,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
...
@@ -132,6 +130,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
tf
.
cast
(
tf
.
experimental
.
numpy
.
tril
(
tf
.
ones
((
self
.
max_positions
,
self
.
max_positions
))),
tf
.
int8
),
tf
.
cast
(
tf
.
experimental
.
numpy
.
tril
(
tf
.
ones
((
self
.
max_positions
,
self
.
max_positions
))),
tf
.
int8
),
(
1
,
1
,
self
.
max_positions
,
self
.
max_positions
),
(
1
,
1
,
self
.
max_positions
,
self
.
max_positions
),
)
)
pos_embd_dim
=
self
.
rotary_dim
or
self
.
embed_dim
self
.
embed_positions
=
create_sinusoidal_positions
(
self
.
max_positions
,
pos_embd_dim
)
def
get_causal_mask
(
self
,
key_length
,
query_length
)
->
tf
.
Tensor
:
def
get_causal_mask
(
self
,
key_length
,
query_length
)
->
tf
.
Tensor
:
return
tf
.
cast
(
self
.
lower_triangle_mask
[:,
:,
key_length
-
query_length
:
key_length
,
:
key_length
],
tf
.
bool
)
return
tf
.
cast
(
self
.
lower_triangle_mask
[:,
:,
key_length
-
query_length
:
key_length
,
:
key_length
],
tf
.
bool
)
...
@@ -207,8 +207,9 @@ class TFGPTJAttention(tf.keras.layers.Layer):
...
@@ -207,8 +207,9 @@ class TFGPTJAttention(tf.keras.layers.Layer):
def
call
(
def
call
(
self
,
self
,
hidden_states
:
tf
.
Tensor
,
hidden_states
:
tf
.
Tensor
,
attention_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
layer_past
:
Optional
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]
=
None
,
layer_past
:
Optional
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]
=
None
,
attention_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
position_ids
:
Optional
[
tf
.
Tensor
]
=
None
,
head_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
head_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
use_cache
:
bool
=
False
,
use_cache
:
bool
=
False
,
output_attentions
:
bool
=
False
,
output_attentions
:
bool
=
False
,
...
@@ -221,13 +222,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
...
@@ -221,13 +222,8 @@ class TFGPTJAttention(tf.keras.layers.Layer):
key
=
self
.
_split_heads
(
key
,
True
)
key
=
self
.
_split_heads
(
key
,
True
)
value
=
self
.
_split_heads
(
value
,
False
)
value
=
self
.
_split_heads
(
value
,
False
)
seq_len
=
shape_list
(
key
)[
1
]
sincos
=
tf
.
gather
(
self
.
embed_positions
,
position_ids
,
axis
=
0
)
offset
=
0
sincos
=
tf
.
split
(
sincos
,
2
,
axis
=-
1
)
if
layer_past
is
not
None
:
offset
=
shape_list
(
layer_past
[
0
])[
-
2
]
seq_len
+=
offset
if
self
.
rotary_dim
is
not
None
:
if
self
.
rotary_dim
is
not
None
:
k_rot
=
key
[:,
:,
:,
:
self
.
rotary_dim
]
k_rot
=
key
[:,
:,
:,
:
self
.
rotary_dim
]
k_pass
=
key
[:,
:,
:,
self
.
rotary_dim
:]
k_pass
=
key
[:,
:,
:,
self
.
rotary_dim
:]
...
@@ -235,16 +231,14 @@ class TFGPTJAttention(tf.keras.layers.Layer):
...
@@ -235,16 +231,14 @@ class TFGPTJAttention(tf.keras.layers.Layer):
q_rot
=
query
[:,
:,
:,
:
self
.
rotary_dim
]
q_rot
=
query
[:,
:,
:,
:
self
.
rotary_dim
]
q_pass
=
query
[:,
:,
:,
self
.
rotary_dim
:]
q_pass
=
query
[:,
:,
:,
self
.
rotary_dim
:]
sincos
=
fixed_pos_embedding
(
k_rot
,
1
,
seq_len
=
seq_len
)
k_rot
=
apply_rotary_pos_emb
(
k_rot
,
sincos
)
k_rot
=
apply_rotary_pos_emb
(
k_rot
,
sincos
,
offset
=
offset
)
q_rot
=
apply_rotary_pos_emb
(
q_rot
,
sincos
)
q_rot
=
apply_rotary_pos_emb
(
q_rot
,
sincos
,
offset
=
offset
)
key
=
tf
.
concat
((
k_rot
,
k_pass
),
axis
=-
1
)
key
=
tf
.
concat
((
k_rot
,
k_pass
),
axis
=-
1
)
query
=
tf
.
concat
((
q_rot
,
q_pass
),
axis
=-
1
)
query
=
tf
.
concat
((
q_rot
,
q_pass
),
axis
=-
1
)
else
:
else
:
sincos
=
fixed_pos_embedding
(
key
,
1
,
seq_len
=
seq_len
)
key
=
apply_rotary_pos_emb
(
key
,
sincos
)
key
=
apply_rotary_pos_emb
(
key
,
sincos
,
offset
=
offset
)
query
=
apply_rotary_pos_emb
(
query
,
sincos
)
query
=
apply_rotary_pos_emb
(
query
,
sincos
,
offset
=
offset
)
key
=
tf
.
transpose
(
key
,
(
0
,
2
,
1
,
3
))
key
=
tf
.
transpose
(
key
,
(
0
,
2
,
1
,
3
))
query
=
tf
.
transpose
(
query
,
(
0
,
2
,
1
,
3
))
query
=
tf
.
transpose
(
query
,
(
0
,
2
,
1
,
3
))
...
@@ -310,6 +304,7 @@ class TFGPTJBlock(tf.keras.layers.Layer):
...
@@ -310,6 +304,7 @@ class TFGPTJBlock(tf.keras.layers.Layer):
hidden_states
:
tf
.
Tensor
,
hidden_states
:
tf
.
Tensor
,
layer_past
:
Optional
[
tf
.
Tensor
]
=
None
,
layer_past
:
Optional
[
tf
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
position_ids
:
Optional
[
tf
.
Tensor
]
=
None
,
head_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
head_mask
:
Optional
[
tf
.
Tensor
]
=
None
,
use_cache
:
bool
=
False
,
use_cache
:
bool
=
False
,
output_attentions
:
bool
=
False
,
output_attentions
:
bool
=
False
,
...
@@ -317,9 +312,10 @@ class TFGPTJBlock(tf.keras.layers.Layer):
...
@@ -317,9 +312,10 @@ class TFGPTJBlock(tf.keras.layers.Layer):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
ln_1
(
hidden_states
)
hidden_states
=
self
.
ln_1
(
hidden_states
)
attn_outputs
=
self
.
attn
(
attn_outputs
=
self
.
attn
(
hidden_states
,
hidden_states
=
hidden_states
,
layer_past
=
layer_past
,
layer_past
=
layer_past
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
head_mask
=
head_mask
,
head_mask
=
head_mask
,
use_cache
=
use_cache
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
,
output_attentions
=
output_attentions
,
...
@@ -466,12 +462,13 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
...
@@ -466,12 +462,13 @@ class TFGPTJMainLayer(tf.keras.layers.Layer):
all_hidden_states
=
all_hidden_states
+
(
tf
.
reshape
(
hidden_states
,
output_shape
),)
all_hidden_states
=
all_hidden_states
+
(
tf
.
reshape
(
hidden_states
,
output_shape
),)
outputs
=
block
(
outputs
=
block
(
hidden_states
,
hidden_states
=
hidden_states
,
layer_past
,
layer_past
=
layer_past
,
attention_mask
,
attention_mask
=
attention_mask
,
head_mask
[
i
],
position_ids
=
position_ids
,
use_cache
,
head_mask
=
head_mask
[
i
],
output_attentions
,
use_cache
=
use_cache
,
output_attentions
=
output_attentions
,
training
=
training
,
training
=
training
,
)
)
...
@@ -722,8 +719,6 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -722,8 +719,6 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
self
.
lm_head
=
tf
.
keras
.
layers
.
Dense
(
self
.
lm_head
=
tf
.
keras
.
layers
.
Dense
(
config
.
vocab_size
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"lm_head"
config
.
vocab_size
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"lm_head"
)
)
# TODO (Joao): investigate why GPTJ has numerical issues in XLA generate
self
.
supports_xla_generation
=
False
def
get_output_embeddings
(
self
):
def
get_output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
...
@@ -731,25 +726,21 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -731,25 +726,21 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
def
set_output_embeddings
(
self
,
new_embeddings
):
def
set_output_embeddings
(
self
,
new_embeddings
):
self
.
lm_head
=
new_embeddings
self
.
lm_head
=
new_embeddings
def
prepare_inputs_for_generation
(
self
,
inputs
,
past
=
None
,
use_cache
=
None
,
use_xla
=
False
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
inputs
,
past
=
None
,
use_cache
=
None
,
**
kwargs
):
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
token_type_ids
=
kwargs
.
get
(
"token_type_ids"
,
None
)
# tests will need to be fixed after the change
# only last token for inputs_ids if past is defined in kwargs
# only last token for inputs_ids if past is defined in kwargs
if
past
:
if
past
:
inputs
=
tf
.
expand_dims
(
inputs
[:,
-
1
],
-
1
)
inputs
=
tf
.
expand_dims
(
inputs
[:,
-
1
],
-
1
)
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
expand_dims
(
token_type_ids
[:,
-
1
],
-
1
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
attention_mask
=
kwargs
.
get
(
"attention_mask"
,
None
)
# TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
if
attention_mask
is
not
None
and
position_ids
is
None
:
# for a future PR to not change too many things for now.
position_ids
=
tf
.
math
.
cumsum
(
attention_mask
,
axis
=-
1
,
exclusive
=
True
)
# All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
if
past
:
position_ids
=
None
position_ids
=
tf
.
expand_dims
(
position_ids
[:,
-
1
],
-
1
)
attention_mask
=
None
if
use_xla
:
attention_mask
=
kwargs
.
get
(
"attention_mask"
,
None
)
if
past
is
not
None
and
attention_mask
is
not
None
:
position_ids
=
tf
.
reduce_sum
(
attention_mask
,
axis
=
1
,
keepdims
=
True
)
-
1
elif
attention_mask
is
not
None
:
position_ids
=
tf
.
math
.
cumsum
(
attention_mask
,
axis
=
1
,
exclusive
=
True
)
return
{
return
{
"input_ids"
:
inputs
,
"input_ids"
:
inputs
,
...
@@ -757,6 +748,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -757,6 +748,7 @@ class TFGPTJForCausalLM(TFGPTJPreTrainedModel, TFCausalLanguageModelingLoss):
"position_ids"
:
position_ids
,
"position_ids"
:
position_ids
,
"past"
:
past
,
"past"
:
past
,
"use_cache"
:
use_cache
,
"use_cache"
:
use_cache
,
"token_type_ids"
:
token_type_ids
,
}
}
@
unpack_inputs
@
unpack_inputs
...
...
tests/models/gptj/test_modeling_tf_gptj.py
View file @
360719a6
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
datetime
import
unittest
import
unittest
from
transformers
import
AutoTokenizer
,
GPTJConfig
,
is_tf_available
from
transformers
import
AutoTokenizer
,
GPTJConfig
,
is_tf_available
...
@@ -48,6 +47,7 @@ class TFGPTJModelTester:
...
@@ -48,6 +47,7 @@ class TFGPTJModelTester:
self
.
use_mc_token_ids
=
True
self
.
use_mc_token_ids
=
True
self
.
vocab_size
=
99
self
.
vocab_size
=
99
self
.
hidden_size
=
32
self
.
hidden_size
=
32
self
.
rotary_dim
=
4
self
.
num_hidden_layers
=
5
self
.
num_hidden_layers
=
5
self
.
num_attention_heads
=
4
self
.
num_attention_heads
=
4
self
.
intermediate_size
=
37
self
.
intermediate_size
=
37
...
@@ -103,6 +103,7 @@ class TFGPTJModelTester:
...
@@ -103,6 +103,7 @@ class TFGPTJModelTester:
bos_token_id
=
self
.
bos_token_id
,
bos_token_id
=
self
.
bos_token_id
,
eos_token_id
=
self
.
eos_token_id
,
eos_token_id
=
self
.
eos_token_id
,
pad_token_id
=
self
.
pad_token_id
,
pad_token_id
=
self
.
pad_token_id
,
rotary_dim
=
self
.
rotary_dim
,
return_dict
=
True
,
return_dict
=
True
,
)
)
...
@@ -359,10 +360,10 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
...
@@ -359,10 +360,10 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
@
require_tf
@
require_tf
@
tooslow
# Marked as @tooslow due to GPU OOM -- but still useful to run locally. Requires ~39GB of RAM.
class
TFGPTJModelLanguageGenerationTest
(
unittest
.
TestCase
):
class
TFGPTJModelLanguageGenerationTest
(
unittest
.
TestCase
):
@
tooslow
def
test_lm_generate_gptj
(
self
):
def
test_lm_generate_gptj
(
self
):
# Marked as @tooslow due to GPU OOM
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
from_pt
=
True
)
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
from_pt
=
True
)
input_ids
=
tf
.
convert_to_tensor
([[
464
,
3290
]],
dtype
=
tf
.
int32
)
# The dog
input_ids
=
tf
.
convert_to_tensor
([[
464
,
3290
]],
dtype
=
tf
.
int32
)
# The dog
# fmt: off
# fmt: off
...
@@ -372,74 +373,20 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
...
@@ -372,74 +373,20 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
False
)
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
False
)
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
self
.
assertListEqual
(
output_ids
[
0
].
numpy
().
tolist
(),
expected_output_ids
)
@
tooslow
def
test_gptj_sample
(
self
):
def
test_gptj_sample
(
self
):
# Marked as @tooslow due to GPU OOM (issue #13676)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
)
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
,
from_pt
=
True
)
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
,
from_pt
=
True
)
t
f
.
random
.
set_seed
(
0
)
t
okenized
=
tokenizer
(
"Today is a nice day and"
,
return_tensors
=
"tf"
)
tokenized
=
tokenizer
(
"Today is a nice day and"
,
return_tensors
=
"tf"
,
return_token_type_ids
=
True
)
# forces the generation to happen on CPU, to avoid GPU-related quirks
input_ids
,
token_type_ids
=
tokenized
.
input_ids
,
tokenized
.
token_type_ids
with
tf
.
device
(
":/CPU:0"
):
output_ids
=
model
.
generate
(
input_ids
,
do_sample
=
True
)
output_ids
=
model
.
generate
(
**
tokenized
,
do_sample
=
True
,
seed
=
[
42
,
0
]
)
output_str
=
tokenizer
.
decode
(
output_ids
[
0
],
skip_special_tokens
=
True
)
output_str
=
tokenizer
.
decode
(
output_ids
[
0
],
skip_special_tokens
=
True
)
output_seq
=
model
.
generate
(
input_ids
=
input_ids
,
do_sample
=
True
,
num_return_sequences
=
5
)
EXPECTED_OUTPUT_STR
=
"Today is a nice day and I’m going to go for a walk. I’"
output_seq_tt
=
model
.
generate
(
input_ids
=
input_ids
,
token_type_ids
=
token_type_ids
,
do_sample
=
True
,
num_return_sequences
=
5
)
output_seq_strs
=
tokenizer
.
batch_decode
(
output_seq
,
skip_special_tokens
=
True
)
output_seq_tt_strs
=
tokenizer
.
batch_decode
(
output_seq_tt
,
skip_special_tokens
=
True
)
EXPECTED_OUTPUT_STR
=
"Today is a nice day and I am taking an hour to sit in the hammock and just enjoy"
self
.
assertEqual
(
output_str
,
EXPECTED_OUTPUT_STR
)
self
.
assertEqual
(
output_str
,
EXPECTED_OUTPUT_STR
)
self
.
assertTrue
(
all
([
output_seq_strs
[
idx
]
!=
output_seq_tt_strs
[
idx
]
for
idx
in
range
(
len
(
output_seq_tt_strs
))])
)
# token_type_ids should change output
@
slow
def
_get_beam_search_test_objects
(
self
):
@
unittest
.
skip
(
reason
=
"TF generate currently has no time-based stopping criteria"
)
def
test_gptj_sample_max_time
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"anton-l/gpt-j-tiny-random"
)
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"anton-l/gpt-j-tiny-random"
,
from_pt
=
True
)
input_ids
=
tokenizer
(
"Today is a nice day and"
,
return_tensors
=
"tf"
,
return_token_type_ids
=
True
).
input_ids
MAX_TIME
=
0.5
start
=
datetime
.
datetime
.
now
()
model
.
generate
(
input_ids
,
do_sample
=
True
,
max_time
=
MAX_TIME
,
max_length
=
256
)
duration
=
datetime
.
datetime
.
now
()
-
start
self
.
assertGreater
(
duration
,
datetime
.
timedelta
(
seconds
=
MAX_TIME
))
self
.
assertLess
(
duration
,
datetime
.
timedelta
(
seconds
=
1.5
*
MAX_TIME
))
start
=
datetime
.
datetime
.
now
()
model
.
generate
(
input_ids
,
do_sample
=
False
,
max_time
=
MAX_TIME
,
max_length
=
256
)
duration
=
datetime
.
datetime
.
now
()
-
start
self
.
assertGreater
(
duration
,
datetime
.
timedelta
(
seconds
=
MAX_TIME
))
self
.
assertLess
(
duration
,
datetime
.
timedelta
(
seconds
=
1.5
*
MAX_TIME
))
start
=
datetime
.
datetime
.
now
()
model
.
generate
(
input_ids
,
do_sample
=
False
,
num_beams
=
2
,
max_time
=
MAX_TIME
,
max_length
=
256
)
duration
=
datetime
.
datetime
.
now
()
-
start
self
.
assertGreater
(
duration
,
datetime
.
timedelta
(
seconds
=
MAX_TIME
))
self
.
assertLess
(
duration
,
datetime
.
timedelta
(
seconds
=
1.5
*
MAX_TIME
))
start
=
datetime
.
datetime
.
now
()
model
.
generate
(
input_ids
,
do_sample
=
True
,
num_beams
=
2
,
max_time
=
MAX_TIME
,
max_length
=
256
)
duration
=
datetime
.
datetime
.
now
()
-
start
self
.
assertGreater
(
duration
,
datetime
.
timedelta
(
seconds
=
MAX_TIME
))
self
.
assertLess
(
duration
,
datetime
.
timedelta
(
seconds
=
1.5
*
MAX_TIME
))
start
=
datetime
.
datetime
.
now
()
model
.
generate
(
input_ids
,
do_sample
=
False
,
max_time
=
None
,
max_length
=
256
)
duration
=
datetime
.
datetime
.
now
()
-
start
self
.
assertGreater
(
duration
,
datetime
.
timedelta
(
seconds
=
1.5
*
MAX_TIME
))
@
tooslow
def
test_batch_generation
(
self
):
# Marked as @tooslow due to GPU OOM
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
,
from_pt
=
True
)
model
=
TFGPTJForCausalLM
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
,
from_pt
=
True
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"EleutherAI/gpt-j-6B"
,
revision
=
"float16"
)
...
@@ -454,42 +401,46 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
...
@@ -454,42 +401,46 @@ class TFGPTJModelLanguageGenerationTest(unittest.TestCase):
"Hello, my dog is a little"
,
"Hello, my dog is a little"
,
"Today, I"
,
"Today, I"
,
]
]
expected_output_sentences
=
[
"Hello, my dog is a little over a year old and has been diagnosed with hip dysplasia"
,
"Today, I’m going to be talking about a topic that’"
,
]
return
model
,
tokenizer
,
sentences
,
expected_output_sentences
inputs
=
tokenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
)
def
test_batch_beam_search
(
self
):
input_ids
=
inputs
[
"input_ids"
]
# Confirms that we get the expected results with left-padded beam search
token_type_ids
=
tf
.
concat
(
model
,
tokenizer
,
sentences
,
expected_output_sentences
=
self
.
_get_beam_search_test_objects
()
[
tf
.
zeros
((
input_ids
.
shape
[
0
],
input_ids
.
shape
[
1
]
-
1
),
dtype
=
tf
.
int64
),
500
*
tf
.
ones
((
input_ids
.
shape
[
0
],
1
),
dtype
=
tf
.
int64
),
],
axis
=-
1
,
)
outputs
=
model
.
generate
(
input_ids
=
input_ids
,
attention_mask
=
inputs
[
"attention_mask"
])
inputs
=
tokenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
)
outputs_tt
=
model
.
generate
(
outputs
=
model
.
generate
(
**
inputs
,
do_sample
=
False
,
num_beams
=
2
)
input_ids
=
input_ids
,
batch_out_sentence
=
tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
attention_mask
=
inputs
[
"attention_mask"
],
self
.
assertListEqual
(
expected_output_sentences
,
batch_out_sentence
)
token_type_ids
=
token_type_ids
,
)
inputs_non_padded
=
tokenizer
(
sentences
[
0
],
return_tensors
=
"tf"
).
input_ids
def
test_batch_left_padding
(
self
):
output_non_padded
=
model
.
generate
(
input_ids
=
inputs_non_padded
)
# Confirms that left-padding is working properly
model
,
tokenizer
,
sentences
,
expected_output_sentences
=
self
.
_get_beam_search_test_objects
()
inputs
=
tokenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
)
inputs_non_padded
=
tokenizer
(
sentences
[
0
],
return_tensors
=
"tf"
)
output_non_padded
=
model
.
generate
(
**
inputs_non_padded
,
do_sample
=
False
,
num_beams
=
2
)
num_paddings
=
(
num_paddings
=
(
shape_list
(
inputs_non_padded
)[
-
1
]
-
tf
.
reduce_sum
(
tf
.
cast
(
inputs
[
"attention_mask"
][
-
1
],
tf
.
int64
)).
numpy
()
shape_list
(
inputs_non_padded
[
"input_ids"
])[
-
1
]
-
tf
.
reduce_sum
(
tf
.
cast
(
inputs
[
"attention_mask"
][
-
1
],
tf
.
int64
)).
numpy
()
)
inputs_padded
=
tokenizer
(
sentences
[
1
],
return_tensors
=
"tf"
)
output_padded
=
model
.
generate
(
**
inputs_padded
,
do_sample
=
False
,
num_beams
=
2
,
max_length
=
model
.
config
.
max_length
-
num_paddings
)
)
inputs_padded
=
tokenizer
(
sentences
[
1
],
return_tensors
=
"tf"
).
input_ids
output_padded
=
model
.
generate
(
input_ids
=
inputs_padded
,
max_length
=
model
.
config
.
max_length
-
num_paddings
)
batch_out_sentence
=
tokenizer
.
batch_decode
(
outputs
,
skip_special_tokens
=
True
)
batch_out_sentence_tt
=
tokenizer
.
batch_decode
(
outputs_tt
,
skip_special_tokens
=
True
)
non_padded_sentence
=
tokenizer
.
decode
(
output_non_padded
[
0
],
skip_special_tokens
=
True
)
non_padded_sentence
=
tokenizer
.
decode
(
output_non_padded
[
0
],
skip_special_tokens
=
True
)
padded_sentence
=
tokenizer
.
decode
(
output_padded
[
0
],
skip_special_tokens
=
True
)
padded_sentence
=
tokenizer
.
decode
(
output_padded
[
0
],
skip_special_tokens
=
True
)
self
.
assertListEqual
(
expected_output_sentences
,
[
non_padded_sentence
,
padded_sentence
])
expected_output_sentence
=
[
def
test_xla_beam_search
(
self
):
"Hello, my dog is a little over a year old and has been diagnosed with a heart murmur"
,
# Confirms that XLA is working properly
"Today, I’m going to share with you a few of my favorite"
,
model
,
tokenizer
,
sentences
,
expected_output_sentences
=
self
.
_get_beam_search_test_objects
()
]
self
.
assertListEqual
(
expected_output_sentence
,
batch_out_sentence
)
inputs
=
tokenizer
(
sentences
,
return_tensors
=
"tf"
,
padding
=
True
)
self
.
assertTrue
(
batch_out_sentence_tt
!=
batch_out_sentence
)
# token_type_ids should change output
xla_generate
=
tf
.
function
(
model
.
generate
,
jit_compile
=
True
)
self
.
assertListEqual
(
expected_output_sentence
,
[
non_padded_sentence
,
padded_sentence
])
outputs_xla
=
xla_generate
(
**
inputs
,
do_sample
=
False
,
num_beams
=
2
)
xla_sentence
=
tokenizer
.
batch_decode
(
outputs_xla
,
skip_special_tokens
=
True
)
self
.
assertListEqual
(
expected_output_sentences
,
xla_sentence
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment