ModelZoo / ResNet50_tensorflow / Commits / 878f6caf

Commit 878f6caf, authored Aug 10, 2020 by A. Unique TensorFlower; committed by saberkun, Aug 10, 2020

Internal change

PiperOrigin-RevId: 325956876

Parent: 0bd679b0
Showing 7 changed files with 629 additions and 0 deletions (+629 −0):

  official/nlp/data/create_pretraining_data_test.py        +141 −0
  official/nlp/data/data_loader_factory_test.py             +45 −0
  official/nlp/data/pretrain_dataloader_test.py            +107 −0
  official/nlp/data/question_answering_dataloader_test.py   +74 −0
  official/nlp/data/sentence_prediction_dataloader_test.py  +83 −0
  official/nlp/data/tagging_data_lib_test.py               +109 −0
  official/nlp/data/tagging_data_loader_test.py             +70 −0
official/nlp/data/create_pretraining_data_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.create_pretraining_data."""
import random

import tensorflow as tf

from official.nlp.data import create_pretraining_data as cpd

_VOCAB_WORDS = ["vocab_1", "vocab_2"]


class CreatePretrainingDataTest(tf.test.TestCase):

  def assertTokens(self, input_tokens, output_tokens, masked_positions,
                   masked_labels):
    # Ensure the masked positions are unique.
    self.assertCountEqual(masked_positions, set(masked_positions))

    # Ensure we can reconstruct the input from the output.
    reconstructed_tokens = output_tokens
    for pos, label in zip(masked_positions, masked_labels):
      reconstructed_tokens[pos] = label
    self.assertEqual(input_tokens, reconstructed_tokens)

    # Ensure each label is valid.
    for pos, label in zip(masked_positions, masked_labels):
      output_token = output_tokens[pos]
      if (output_token == "[MASK]" or output_token in _VOCAB_WORDS or
          output_token == input_tokens[pos]):
        continue
      self.fail("invalid mask value: {}".format(output_token))

  def test_wordpieces_to_grams(self):
    tests = [
        (["That", "cone"], [(0, 1), (1, 2)]),
        (["That", "cone", "##s"], [(0, 1), (1, 3)]),
        (["Swit", "##zer", "##land"], [(0, 3)]),
        (["[CLS]", "Up", "##dog"], [(1, 3)]),
        (["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
    ]
    for inp, expected in tests:
      output = cpd._wordpieces_to_grams(inp)
      self.assertEqual(expected, output)

  def test_window(self):
    input_list = [1, 2, 3, 4]
    window_outputs = [
        (1, [[1], [2], [3], [4]]),
        (2, [[1, 2], [2, 3], [3, 4]]),
        (3, [[1, 2, 3], [2, 3, 4]]),
        (4, [[1, 2, 3, 4]]),
        (5, []),
    ]
    for window, expected in window_outputs:
      output = cpd._window(input_list, window)
      self.assertEqual(expected, list(output))

  def test_create_masked_lm_predictions(self):
    tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
    rng = random.Random(123)
    for _ in range(0, 5):
      output_tokens, masked_positions, masked_labels = (
          cpd.create_masked_lm_predictions(
              tokens=tokens,
              masked_lm_prob=1.0,
              max_predictions_per_seq=3,
              vocab_words=_VOCAB_WORDS,
              rng=rng,
              do_whole_word_mask=False,
              max_ngram_size=None))
      self.assertEqual(len(masked_positions), 3)
      self.assertEqual(len(masked_labels), 3)
      self.assertTokens(tokens, output_tokens, masked_positions,
                        masked_labels)

  def test_create_masked_lm_predictions_whole_word(self):
    tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
    rng = random.Random(345)
    for _ in range(0, 5):
      output_tokens, masked_positions, masked_labels = (
          cpd.create_masked_lm_predictions(
              tokens=tokens,
              masked_lm_prob=1.0,
              max_predictions_per_seq=3,
              vocab_words=_VOCAB_WORDS,
              rng=rng,
              do_whole_word_mask=True,
              max_ngram_size=None))
      # Since we can't get exactly three tokens without breaking a word, we
      # only take two.
      self.assertEqual(len(masked_positions), 2)
      self.assertEqual(len(masked_labels), 2)
      self.assertTokens(tokens, output_tokens, masked_positions,
                        masked_labels)
      # Ensure that we took an entire word.
      self.assertIn(masked_labels,
                    [["a", "##a"], ["b", "##b"], ["c", "##c"]])

  def test_create_masked_lm_predictions_ngram(self):
    tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
    rng = random.Random(345)
    for _ in range(0, 5):
      output_tokens, masked_positions, masked_labels = (
          cpd.create_masked_lm_predictions(
              tokens=tokens,
              masked_lm_prob=1.0,
              max_predictions_per_seq=76,
              vocab_words=_VOCAB_WORDS,
              rng=rng,
              do_whole_word_mask=True,
              max_ngram_size=3))
      self.assertEqual(len(masked_positions), 76)
      self.assertEqual(len(masked_labels), 76)
      self.assertTokens(tokens, output_tokens, masked_positions,
                        masked_labels)


if __name__ == "__main__":
  tf.test.main()
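For orientation, the masking API these tests exercise can also be driven directly. Below is a minimal sketch, assuming the create_masked_lm_predictions keyword signature used in the test above; the tokens and vocabulary are made up for illustration.

import random

from official.nlp.data import create_pretraining_data as cpd

# Hypothetical wordpiece sequence and vocabulary, for illustration only.
tokens = ["[CLS]", "the", "qu", "##ick", "fox", "[SEP]"]
vocab_words = ["the", "fox", "dog"]
rng = random.Random(42)

output_tokens, masked_positions, masked_labels = (
    cpd.create_masked_lm_predictions(
        tokens=tokens,
        masked_lm_prob=0.15,        # mask ~15% of tokens, the usual BERT rate
        max_predictions_per_seq=2,  # cap on masked positions per sequence
        vocab_words=vocab_words,
        rng=rng,
        do_whole_word_mask=True,    # mask whole words, not lone wordpieces
        max_ngram_size=None))
print(output_tokens, masked_positions, masked_labels)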
official/nlp/data/data_loader_factory_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.data_loader_factory."""
import dataclasses

import tensorflow as tf

from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class MyDataConfig(cfg.DataConfig):
  is_training: bool = True


@data_loader_factory.register_data_loader_cls(MyDataConfig)
class MyDataLoader:

  def __init__(self, params):
    self.params = params


class DataLoaderFactoryTest(tf.test.TestCase):

  def test_register_and_load(self):
    train_config = MyDataConfig()
    train_loader = data_loader_factory.get_data_loader(train_config)
    self.assertTrue(train_loader.params.is_training)


if __name__ == "__main__":
  tf.test.main()
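The registration pattern tested here generalizes to any config/loader pair. A sketch under the same assumptions as the test; MyRegressionConfig, MyRegressionLoader, and label_scale are hypothetical names, and only register_data_loader_cls and get_data_loader come from the module under test.

import dataclasses

from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class MyRegressionConfig(cfg.DataConfig):
  label_scale: float = 1.0  # hypothetical extra field


@data_loader_factory.register_data_loader_cls(MyRegressionConfig)
class MyRegressionLoader:

  def __init__(self, params):
    self.params = params


# The factory dispatches on the concrete type of the config instance.
loader = data_loader_factory.get_data_loader(
    MyRegressionConfig(label_scale=2.0))
assert loader.params.label_scale == 2.0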
official/nlp/data/pretrain_dataloader_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.pretrain_dataloader."""
import itertools
import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.nlp.data import pretrain_dataloader


def _create_fake_dataset(output_path, seq_length, max_predictions_per_seq,
                         use_position_id, use_next_sentence_label):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  def create_float_feature(values):
    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return f

  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(np.ones_like(input_ids))
    features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
    features["masked_lm_positions"] = create_int_feature(
        np.random.randint(100, size=(max_predictions_per_seq)))
    features["masked_lm_ids"] = create_int_feature(
        np.random.randint(100, size=(max_predictions_per_seq)))
    features["masked_lm_weights"] = create_float_feature(
        [1.0] * max_predictions_per_seq)
    if use_next_sentence_label:
      features["next_sentence_labels"] = create_int_feature([1])
    if use_position_id:
      features["position_ids"] = create_int_feature(range(0, seq_length))

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(itertools.product(
      (False, True),
      (False, True),
  ))
  def test_load_data(self, use_next_sentence_label, use_position_id):
    train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
    seq_length = 128
    max_predictions_per_seq = 20
    _create_fake_dataset(
        train_data_path,
        seq_length,
        max_predictions_per_seq,
        use_next_sentence_label=use_next_sentence_label,
        use_position_id=use_position_id)
    data_config = pretrain_dataloader.BertPretrainDataConfig(
        input_path=train_data_path,
        max_predictions_per_seq=max_predictions_per_seq,
        seq_length=seq_length,
        global_batch_size=10,
        is_training=True,
        use_next_sentence_label=use_next_sentence_label,
        use_position_id=use_position_id)
    dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
    features = next(iter(dataset))
    self.assertLen(
        features,
        6 + int(use_next_sentence_label) + int(use_position_id))
    self.assertIn("input_word_ids", features)
    self.assertIn("input_mask", features)
    self.assertIn("input_type_ids", features)
    self.assertIn("masked_lm_positions", features)
    self.assertIn("masked_lm_ids", features)
    self.assertIn("masked_lm_weights", features)
    self.assertEqual("next_sentence_labels" in features,
                     use_next_sentence_label)
    self.assertEqual("position_ids" in features, use_position_id)


if __name__ == "__main__":
  tf.test.main()
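Outside the test, the same loader is configured from a real TFRecord path. A minimal sketch reusing the config values from the test above; the input path is a placeholder.

from official.nlp.data import pretrain_dataloader

config = pretrain_dataloader.BertPretrainDataConfig(
    input_path="/tmp/pretrain/train.tf_record",  # placeholder path
    seq_length=128,
    max_predictions_per_seq=20,
    global_batch_size=10,
    is_training=True,
    use_next_sentence_label=True,
    use_position_id=False)
dataset = pretrain_dataloader.BertPretrainDataLoader(config).load()
for features in dataset.take(1):
  # Pretraining batches carry a features dict only (no separate labels).
  print(sorted(features.keys()))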
official/nlp/data/question_answering_dataloader_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.question_answering_dataloader."""
import os

import numpy as np
import tensorflow as tf

from official.nlp.data import question_answering_dataloader


def _create_fake_dataset(output_path, seq_length):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features['input_ids'] = create_int_feature(input_ids)
    features['input_mask'] = create_int_feature(np.ones_like(input_ids))
    features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
    features['start_positions'] = create_int_feature(np.array([0]))
    features['end_positions'] = create_int_feature(np.array([10]))

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class QuestionAnsweringDataTest(tf.test.TestCase):

  def test_load_dataset(self):
    seq_length = 128
    batch_size = 10
    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(input_path, seq_length)
    data_config = question_answering_dataloader.QADataConfig(
        is_training=True,
        input_path=input_path,
        seq_length=seq_length,
        global_batch_size=batch_size)
    dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
        data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape,
                     (batch_size, seq_length))
    self.assertCountEqual(['start_positions', 'end_positions'],
                          labels.keys())
    self.assertEqual(labels['start_positions'].shape, (batch_size,))
    self.assertEqual(labels['end_positions'].shape, (batch_size,))


if __name__ == '__main__':
  tf.test.main()
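The same flow works standalone. A sketch mirroring the config the test builds, with a placeholder input path; the feature and label names follow the assertions above.

from official.nlp.data import question_answering_dataloader as qad

config = qad.QADataConfig(
    is_training=True,
    input_path='/tmp/squad/train.tf_record',  # placeholder path
    seq_length=128,
    global_batch_size=10)
features, labels = next(iter(qad.QuestionAnsweringDataLoader(config).load()))
# features: input_word_ids, input_mask, input_type_ids, each (batch, seq).
# labels: start_positions and end_positions, each (batch,).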
official/nlp/data/sentence_prediction_dataloader_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.sentence_prediction_dataloader."""
import os

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.nlp.data import sentence_prediction_dataloader


def _create_fake_dataset(output_path, seq_length, label_type):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  def create_float_feature(values):
    f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return f

  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features['input_ids'] = create_int_feature(input_ids)
    features['input_mask'] = create_int_feature(np.ones_like(input_ids))
    features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
    if label_type == 'int':
      features['label_ids'] = create_int_feature([1])
    elif label_type == 'float':
      features['label_ids'] = create_float_feature([0.5])
    else:
      raise ValueError('Unsupported label_type: %s' % label_type)

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('int', tf.int32), ('float', tf.float32))
  def test_load_dataset(self, label_type, expected_label_type):
    input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    batch_size = 10
    seq_length = 128
    _create_fake_dataset(input_path, seq_length, label_type)
    data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(
        input_path=input_path,
        seq_length=seq_length,
        global_batch_size=batch_size,
        label_type=label_type)
    dataset = sentence_prediction_dataloader.SentencePredictionDataLoader(
        data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size,))
    self.assertEqual(labels.dtype, expected_label_type)


if __name__ == '__main__':
  tf.test.main()
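The parameterization above boils down to one config knob. A sketch of the two label modes with a placeholder path: label_type='int' yields tf.int32 labels (classification) and label_type='float' yields tf.float32 labels (regression-style targets).

from official.nlp.data import sentence_prediction_dataloader as spd

config = spd.SentencePredictionDataConfig(
    input_path='/tmp/glue/train.tf_record',  # placeholder path
    seq_length=128,
    global_batch_size=10,
    label_type='float')  # 'int' -> tf.int32 labels, 'float' -> tf.float32
features, labels = next(iter(
    spd.SentencePredictionDataLoader(config).load()))
print(labels.dtype)  # tf.float32 here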
official/nlp/data/tagging_data_lib_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.tagging_data_lib."""
import os
import random

from absl.testing import parameterized
import tensorflow as tf

from official.nlp.bert import tokenization
from official.nlp.data import tagging_data_lib


def _create_fake_file(filename, labels, is_test):

  def write_one_sentence(writer, length):
    for _ in range(length):
      line = "hiworld"
      if not is_test:
        line += "\t%s" % (labels[random.randint(0, len(labels) - 1)])
      writer.write(line + "\n")

  # Writes two sentences with length of 3 and 12 respectively.
  with tf.io.gfile.GFile(filename, "w") as writer:
    write_one_sentence(writer, 3)
    writer.write("\n")
    write_one_sentence(writer, 12)


class TaggingDataLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(TaggingDataLibTest, self).setUp()

    self.processors = {
        "panx": tagging_data_lib.PanxProcessor,
        "udpos": tagging_data_lib.UdposProcessor,
    }
    self.vocab_file = os.path.join(self.get_temp_dir(), "vocab.txt")
    with tf.io.gfile.GFile(self.vocab_file, "w") as writer:
      writer.write("\n".join(["[CLS]", "[SEP]", "hi", "##world", "[UNK]"]))

  @parameterized.parameters(
      {"task_type": "panx"},
      {"task_type": "udpos"},
  )
  def test_generate_tf_record(self, task_type):
    processor = self.processors[task_type]()
    input_data_dir = os.path.join(self.get_temp_dir(), task_type)
    tf.io.gfile.mkdir(input_data_dir)
    # Write fake train file.
    _create_fake_file(
        os.path.join(input_data_dir, "train-en.tsv"),
        processor.get_labels(),
        is_test=False)
    # Write fake dev file.
    _create_fake_file(
        os.path.join(input_data_dir, "dev-en.tsv"),
        processor.get_labels(),
        is_test=False)
    # Write fake test files.
    for lang in processor.supported_languages:
      _create_fake_file(
          os.path.join(input_data_dir, "test-%s.tsv" % lang),
          processor.get_labels(),
          is_test=True)

    output_path = os.path.join(self.get_temp_dir(), task_type, "output")
    tokenizer = tokenization.FullTokenizer(
        vocab_file=self.vocab_file, do_lower_case=True)
    metadata = tagging_data_lib.generate_tf_record_from_data_file(
        processor,
        input_data_dir,
        tokenizer,
        max_seq_length=8,
        train_data_output_path=os.path.join(output_path, "train.tfrecord"),
        eval_data_output_path=os.path.join(output_path, "eval.tfrecord"),
        test_data_output_path=os.path.join(output_path, "test_{}.tfrecord"),
        text_preprocessing=tokenization.convert_to_unicode)

    self.assertEqual(metadata["train_data_size"], 5)
    files = tf.io.gfile.glob(output_path + "/*")
    expected_files = []
    expected_files.append(os.path.join(output_path, "train.tfrecord"))
    expected_files.append(os.path.join(output_path, "eval.tfrecord"))
    for lang in processor.supported_languages:
      expected_files.append(
          os.path.join(output_path, "test_%s.tfrecord" % lang))
    self.assertCountEqual(files, expected_files)


if __name__ == "__main__":
  tf.test.main()
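The generation call is the reusable piece here. A sketch with placeholder paths, keeping the keyword arguments the test uses; the vocab file and the directory of train-/dev-/test-*.tsv files are assumed to exist.

from official.nlp.bert import tokenization
from official.nlp.data import tagging_data_lib

processor = tagging_data_lib.PanxProcessor()
tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/panx/vocab.txt", do_lower_case=True)  # placeholder vocab
metadata = tagging_data_lib.generate_tf_record_from_data_file(
    processor,
    "/tmp/panx/data",  # placeholder input directory
    tokenizer,
    max_seq_length=128,
    train_data_output_path="/tmp/panx/out/train.tfrecord",
    eval_data_output_path="/tmp/panx/out/eval.tfrecord",
    test_data_output_path="/tmp/panx/out/test_{}.tfrecord",
    text_preprocessing=tokenization.convert_to_unicode)
print(metadata["train_data_size"])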
official/nlp/data/tagging_data_loader_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.tagging_data_loader."""
import os

import numpy as np
import tensorflow as tf

from official.nlp.data import tagging_data_loader


def _create_fake_dataset(output_path, seq_length):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)

  def create_int_feature(values):
    f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return f

  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    features['input_ids'] = create_int_feature(input_ids)
    features['input_mask'] = create_int_feature(np.ones_like(input_ids))
    features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
    features['label_ids'] = create_int_feature(
        np.random.randint(10, size=(seq_length)))

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class TaggingDataLoaderTest(tf.test.TestCase):

  def test_load_dataset(self):
    seq_length = 16
    batch_size = 10
    train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
    _create_fake_dataset(train_data_path, seq_length)
    data_config = tagging_data_loader.TaggingDataConfig(
        input_path=train_data_path,
        seq_length=seq_length,
        global_batch_size=batch_size)

    dataset = tagging_data_loader.TaggingDataLoader(data_config).load()
    features, labels = next(iter(dataset))
    self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
                          features.keys())
    self.assertEqual(features['input_word_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
    self.assertEqual(features['input_type_ids'].shape,
                     (batch_size, seq_length))
    self.assertEqual(labels.shape, (batch_size, seq_length))


if __name__ == '__main__':
  tf.test.main()
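A sketch of this loader in isolation, reusing the test's config values with a placeholder path; note that tagging labels are per-token, so they keep the sequence dimension.

from official.nlp.data import tagging_data_loader

config = tagging_data_loader.TaggingDataConfig(
    input_path='/tmp/tagging/train.tf_record',  # placeholder path
    seq_length=16,
    global_batch_size=10)
features, labels = next(iter(
    tagging_data_loader.TaggingDataLoader(config).load()))
print(labels.shape)  # (batch_size, seq_length): one tag id per token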