Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1d8b2263
Commit
1d8b2263
authored
Jul 03, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 319495967
parent
bed83905
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
104 additions
and
41 deletions
+104
-41
official/nlp/configs/bert.py
official/nlp/configs/bert.py
+0
-27
official/nlp/data/question_answering_dataloader.py
official/nlp/data/question_answering_dataloader.py
+95
-0
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+5
-12
official/nlp/tasks/question_answering_test.py
official/nlp/tasks/question_answering_test.py
+4
-2
No files found.
official/nlp/configs/bert.py
View file @
1d8b2263
...
@@ -24,7 +24,6 @@ import tensorflow as tf
...
@@ -24,7 +24,6 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
layers
from
official.nlp.modeling.models
import
bert_pretrainer
from
official.nlp.modeling.models
import
bert_pretrainer
...
@@ -72,29 +71,3 @@ def instantiate_bertpretrainer_from_cfg(
...
@@ -72,29 +71,3 @@ def instantiate_bertpretrainer_from_cfg(
encoder_network
=
encoder_network
,
encoder_network
=
encoder_network
,
classification_heads
=
instantiate_classification_heads_from_cfgs
(
classification_heads
=
instantiate_classification_heads_from_cfgs
(
config
.
cls_heads
))
config
.
cls_heads
))
@
dataclasses
.
dataclass
class
QADataConfig
(
cfg
.
DataConfig
):
"""Data config for question answering task (tasks/question_answering)."""
input_path
:
str
=
""
global_batch_size
:
int
=
48
is_training
:
bool
=
True
seq_length
:
int
=
384
@
dataclasses
.
dataclass
class
QADevDataConfig
(
cfg
.
DataConfig
):
"""Dev Data config for queston answering (tasks/question_answering)."""
input_path
:
str
=
""
input_preprocessed_data_path
:
str
=
""
version_2_with_negative
:
bool
=
False
doc_stride
:
int
=
128
global_batch_size
:
int
=
48
is_training
:
bool
=
False
seq_length
:
int
=
384
query_length
:
int
=
64
drop_remainder
:
bool
=
False
vocab_file
:
str
=
""
tokenization
:
str
=
"WordPiece"
# WordPiece or SentencePiece
do_lower_case
:
bool
=
True
official/nlp/data/question_answering_dataloader.py
0 → 100644
View file @
1d8b2263
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from
typing
import
Mapping
,
Optional
import
dataclasses
import
tensorflow
as
tf
from
official.core
import
input_reader
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.data
import
data_loader_factory
@
dataclasses
.
dataclass
class
QADataConfig
(
cfg
.
DataConfig
):
"""Data config for question answering task (tasks/question_answering)."""
input_path
:
str
=
''
global_batch_size
:
int
=
48
is_training
:
bool
=
True
seq_length
:
int
=
384
# Settings below are question answering specific.
version_2_with_negative
:
bool
=
False
# Settings below are only used for eval mode.
input_preprocessed_data_path
:
str
=
''
doc_stride
:
int
=
128
query_length
:
int
=
64
vocab_file
:
str
=
''
tokenization
:
str
=
'WordPiece'
# WordPiece or SentencePiece
do_lower_case
:
bool
=
True
@
data_loader_factory
.
register_data_loader_cls
(
QADataConfig
)
class
QuestionAnsweringDataLoader
:
"""A class to load dataset for sentence prediction (classification) task."""
def
__init__
(
self
,
params
):
self
.
_params
=
params
self
.
_seq_length
=
params
.
seq_length
self
.
_is_training
=
params
.
is_training
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
name_to_features
=
{
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_mask'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
}
if
self
.
_is_training
:
name_to_features
[
'start_positions'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
name_to_features
[
'end_positions'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
else
:
name_to_features
[
'unique_ids'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
example
=
tf
.
io
.
parse_single_example
(
record
,
name_to_features
)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for
name
in
example
:
t
=
example
[
name
]
if
t
.
dtype
==
tf
.
int64
:
t
=
tf
.
cast
(
t
,
tf
.
int32
)
example
[
name
]
=
t
return
example
def
_parse
(
self
,
record
:
Mapping
[
str
,
tf
.
Tensor
]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x
,
y
=
{},
{}
for
name
,
tensor
in
record
.
items
():
if
name
in
(
'start_positions'
,
'end_positions'
):
y
[
name
]
=
tensor
elif
name
==
'input_ids'
:
x
[
'input_word_ids'
]
=
tensor
elif
name
==
'segment_ids'
:
x
[
'input_type_ids'
]
=
tensor
else
:
x
[
name
]
=
tensor
return
(
x
,
y
)
def
load
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
"""Returns a tf.dataset.Dataset."""
reader
=
input_reader
.
InputReader
(
params
=
self
.
_params
,
decoder_fn
=
self
.
_decode
,
parser_fn
=
self
.
_parse
)
return
reader
.
read
(
input_context
)
official/nlp/tasks/question_answering.py
View file @
1d8b2263
...
@@ -24,11 +24,11 @@ import tensorflow_hub as hub
...
@@ -24,11 +24,11 @@ import tensorflow_hub as hub
from
official.core
import
base_task
from
official.core
import
base_task
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.bert
import
input_pipeline
from
official.nlp.bert
import
squad_evaluate_v1_1
from
official.nlp.bert
import
squad_evaluate_v1_1
from
official.nlp.bert
import
squad_evaluate_v2_0
from
official.nlp.bert
import
squad_evaluate_v2_0
from
official.nlp.bert
import
tokenization
from
official.nlp.bert
import
tokenization
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
data_loader_factory
from
official.nlp.data
import
squad_lib
as
squad_lib_wp
from
official.nlp.data
import
squad_lib
as
squad_lib_wp
from
official.nlp.data
import
squad_lib_sp
from
official.nlp.data
import
squad_lib_sp
from
official.nlp.modeling
import
models
from
official.nlp.modeling
import
models
...
@@ -174,20 +174,13 @@ class QuestionAnsweringTask(base_task.Task):
...
@@ -174,20 +174,13 @@ class QuestionAnsweringTask(base_task.Task):
return
dataset
return
dataset
if
params
.
is_training
:
if
params
.
is_training
:
input_path
=
params
.
input_path
dataloader_params
=
params
else
:
else
:
input_path
=
self
.
_tf_record_input_path
input_path
=
self
.
_tf_record_input_path
dataloader_params
=
params
.
replace
(
input_path
=
input_path
)
batch_size
=
input_context
.
get_per_replica_batch_size
(
return
data_loader_factory
.
get_data_loader
(
params
.
global_batch_size
)
if
input_context
else
params
.
global_batch_size
dataloader_params
).
load
(
input_context
)
# TODO(chendouble): add and use nlp.data.question_answering_dataloader.
dataset
=
input_pipeline
.
create_squad_dataset
(
input_path
,
params
.
seq_length
,
batch_size
,
is_training
=
params
.
is_training
,
input_pipeline_context
=
input_context
)
return
dataset
def
build_metrics
(
self
,
training
=
None
):
def
build_metrics
(
self
,
training
=
None
):
del
training
del
training
...
...
official/nlp/tasks/question_answering_test.py
View file @
1d8b2263
...
@@ -24,6 +24,7 @@ from official.nlp.bert import configs
...
@@ -24,6 +24,7 @@ from official.nlp.bert import configs
from
official.nlp.bert
import
export_tfhub
from
official.nlp.bert
import
export_tfhub
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
question_answering_dataloader
from
official.nlp.tasks
import
question_answering
from
official.nlp.tasks
import
question_answering
...
@@ -33,7 +34,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -33,7 +34,7 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
super
(
QuestionAnsweringTaskTest
,
self
).
setUp
()
super
(
QuestionAnsweringTaskTest
,
self
).
setUp
()
self
.
_encoder_config
=
encoders
.
TransformerEncoderConfig
(
self
.
_encoder_config
=
encoders
.
TransformerEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)
vocab_size
=
30522
,
num_layers
=
1
)
self
.
_train_data_config
=
bert
.
QADataConfig
(
self
.
_train_data_config
=
question_answering_dataloader
.
QADataConfig
(
input_path
=
"dummy"
,
input_path
=
"dummy"
,
seq_length
=
128
,
seq_length
=
128
,
global_batch_size
=
1
)
global_batch_size
=
1
)
...
@@ -55,7 +56,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -55,7 +56,8 @@ class QuestionAnsweringTaskTest(tf.test.TestCase, parameterized.TestCase):
writer
.
write
(
"[PAD]
\n
[UNK]
\n
[CLS]
\n
[SEP]
\n
[MASK]
\n
sky
\n
is
\n
blue
\n
"
)
writer
.
write
(
"[PAD]
\n
[UNK]
\n
[CLS]
\n
[SEP]
\n
[MASK]
\n
sky
\n
is
\n
blue
\n
"
)
def
_get_validation_data_config
(
self
,
version_2_with_negative
=
False
):
def
_get_validation_data_config
(
self
,
version_2_with_negative
=
False
):
return
bert
.
QADevDataConfig
(
return
question_answering_dataloader
.
QADataConfig
(
is_training
=
False
,
input_path
=
self
.
_val_input_path
,
input_path
=
self
.
_val_input_path
,
input_preprocessed_data_path
=
self
.
get_temp_dir
(),
input_preprocessed_data_path
=
self
.
get_temp_dir
(),
seq_length
=
128
,
seq_length
=
128
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment