Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c4451b7a
Commit
c4451b7a
authored
Jul 01, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jul 01, 2020
Browse files
Internal change
PiperOrigin-RevId: 319267378
parent
7b5cb554
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
115 additions
and
76 deletions
+115
-76
official/nlp/configs/bert.py
official/nlp/configs/bert.py
+0
-58
official/nlp/data/data_loader_factory.py
official/nlp/data/data_loader_factory.py
+59
-0
official/nlp/data/pretrain_dataloader.py
official/nlp/data/pretrain_dataloader.py
+17
-3
official/nlp/data/sentence_prediction_dataloader.py
official/nlp/data/sentence_prediction_dataloader.py
+13
-0
official/nlp/data/tagging_data_loader.py
official/nlp/data/tagging_data_loader.py
+12
-1
official/nlp/tasks/masked_lm.py
official/nlp/tasks/masked_lm.py
+2
-3
official/nlp/tasks/masked_lm_test.py
official/nlp/tasks/masked_lm_test.py
+2
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+2
-3
official/nlp/tasks/sentence_prediction_test.py
official/nlp/tasks/sentence_prediction_test.py
+4
-2
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+2
-3
official/nlp/tasks/tagging_test.py
official/nlp/tasks/tagging_test.py
+2
-2
No files found.
official/nlp/configs/bert.py
View file @
c4451b7a
...
@@ -74,45 +74,6 @@ def instantiate_bertpretrainer_from_cfg(
...
@@ -74,45 +74,6 @@ def instantiate_bertpretrainer_from_cfg(
config
.
cls_heads
))
config
.
cls_heads
))
@
dataclasses
.
dataclass
class
BertPretrainDataConfig
(
cfg
.
DataConfig
):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path
:
str
=
""
global_batch_size
:
int
=
512
is_training
:
bool
=
True
seq_length
:
int
=
512
max_predictions_per_seq
:
int
=
76
use_next_sentence_label
:
bool
=
True
use_position_id
:
bool
=
False
@
dataclasses
.
dataclass
class
BertPretrainEvalDataConfig
(
BertPretrainDataConfig
):
"""Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
input_path
:
str
=
""
global_batch_size
:
int
=
512
is_training
:
bool
=
False
@
dataclasses
.
dataclass
class
SentencePredictionDataConfig
(
cfg
.
DataConfig
):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path
:
str
=
""
global_batch_size
:
int
=
32
is_training
:
bool
=
True
seq_length
:
int
=
128
@
dataclasses
.
dataclass
class
SentencePredictionDevDataConfig
(
cfg
.
DataConfig
):
"""Dev Data config for sentence prediction (tasks/sentence_prediction)."""
input_path
:
str
=
""
global_batch_size
:
int
=
32
is_training
:
bool
=
False
seq_length
:
int
=
128
drop_remainder
:
bool
=
False
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
QADataConfig
(
cfg
.
DataConfig
):
class
QADataConfig
(
cfg
.
DataConfig
):
"""Data config for question answering task (tasks/question_answering)."""
"""Data config for question answering task (tasks/question_answering)."""
...
@@ -137,22 +98,3 @@ class QADevDataConfig(cfg.DataConfig):
...
@@ -137,22 +98,3 @@ class QADevDataConfig(cfg.DataConfig):
vocab_file
:
str
=
""
vocab_file
:
str
=
""
tokenization
:
str
=
"WordPiece"
# WordPiece or SentencePiece
tokenization
:
str
=
"WordPiece"
# WordPiece or SentencePiece
do_lower_case
:
bool
=
True
do_lower_case
:
bool
=
True
@
dataclasses
.
dataclass
class
TaggingDataConfig
(
cfg
.
DataConfig
):
"""Data config for tagging (tasks/tagging)."""
input_path
:
str
=
""
global_batch_size
:
int
=
48
is_training
:
bool
=
True
seq_length
:
int
=
384
@
dataclasses
.
dataclass
class
TaggingDevDataConfig
(
cfg
.
DataConfig
):
"""Dev Data config for tagging (tasks/tagging)."""
input_path
:
str
=
""
global_batch_size
:
int
=
48
is_training
:
bool
=
False
seq_length
:
int
=
384
drop_remainder
:
bool
=
False
official/nlp/data/data_loader_factory.py
0 → 100644
View file @
c4451b7a
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
from
official.utils
import
registry
_REGISTERED_DATA_LOADER_CLS
=
{}
def
register_data_loader_cls
(
data_config_cls
):
"""Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
This decorator supports registration of data loaders as follows:
```
@dataclasses.dataclass
class MyDataConfig(DataConfig):
# Add fields here.
pass
@register_data_loader_cls(MyDataConfig)
class MyDataLoader:
# Inherits def __init__(self, data_config).
pass
my_data_config = MyDataConfig()
# Returns MyDataLoader(my_data_config).
my_loader = get_data_loader(my_data_config)
```
Args:
data_config_cls: a subclass of DataConfig (*not* an instance
of DataConfig).
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of data_config_cls.
"""
return
registry
.
register
(
_REGISTERED_DATA_LOADER_CLS
,
data_config_cls
)
def
get_data_loader
(
data_config
):
"""Creates a data_loader from data_config."""
return
registry
.
lookup
(
_REGISTERED_DATA_LOADER_CLS
,
data_config
.
__class__
)(
data_config
)
official/nlp/data/pretrain_dataloader.py
View file @
c4451b7a
...
@@ -16,11 +16,27 @@
...
@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task."""
"""Loads dataset for the BERT pretraining task."""
from
typing
import
Mapping
,
Optional
from
typing
import
Mapping
,
Optional
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
input_reader
from
official.core
import
input_reader
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.data
import
data_loader_factory
@
dataclasses
.
dataclass
class
BertPretrainDataConfig
(
cfg
.
DataConfig
):
"""Data config for BERT pretraining task (tasks/masked_lm)."""
input_path
:
str
=
''
global_batch_size
:
int
=
512
is_training
:
bool
=
True
seq_length
:
int
=
512
max_predictions_per_seq
:
int
=
76
use_next_sentence_label
:
bool
=
True
use_position_id
:
bool
=
False
@
data_loader_factory
.
register_data_loader_cls
(
BertPretrainDataConfig
)
class
BertPretrainDataLoader
:
class
BertPretrainDataLoader
:
"""A class to load dataset for bert pretraining task."""
"""A class to load dataset for bert pretraining task."""
...
@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
...
@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def
load
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
def
load
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
"""Returns a tf.dataset.Dataset."""
"""Returns a tf.dataset.Dataset."""
reader
=
input_reader
.
InputReader
(
reader
=
input_reader
.
InputReader
(
params
=
self
.
_params
,
params
=
self
.
_params
,
decoder_fn
=
self
.
_decode
,
parser_fn
=
self
.
_parse
)
decoder_fn
=
self
.
_decode
,
parser_fn
=
self
.
_parse
)
return
reader
.
read
(
input_context
)
return
reader
.
read
(
input_context
)
official/nlp/data/sentence_prediction_dataloader.py
View file @
c4451b7a
...
@@ -15,11 +15,24 @@
...
@@ -15,11 +15,24 @@
# ==============================================================================
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
"""Loads dataset for the sentence prediction (classification) task."""
from
typing
import
Mapping
,
Optional
from
typing
import
Mapping
,
Optional
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
input_reader
from
official.core
import
input_reader
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.data
import
data_loader_factory
@
dataclasses
.
dataclass
class
SentencePredictionDataConfig
(
cfg
.
DataConfig
):
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
input_path
:
str
=
''
global_batch_size
:
int
=
32
is_training
:
bool
=
True
seq_length
:
int
=
128
@
data_loader_factory
.
register_data_loader_cls
(
SentencePredictionDataConfig
)
class
SentencePredictionDataLoader
:
class
SentencePredictionDataLoader
:
"""A class to load dataset for sentence prediction (classification) task."""
"""A class to load dataset for sentence prediction (classification) task."""
...
...
official/nlp/data/tagging_data_loader.py
View file @
c4451b7a
...
@@ -15,15 +15,26 @@
...
@@ -15,15 +15,26 @@
# ==============================================================================
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from
typing
import
Mapping
,
Optional
from
typing
import
Mapping
,
Optional
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
input_reader
from
official.core
import
input_reader
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.data
import
data_loader_factory
@
dataclasses
.
dataclass
class
TaggingDataConfig
(
cfg
.
DataConfig
):
"""Data config for tagging (tasks/tagging)."""
is_training
:
bool
=
True
seq_length
:
int
=
128
@
data_loader_factory
.
register_data_loader_cls
(
TaggingDataConfig
)
class
TaggingDataLoader
:
class
TaggingDataLoader
:
"""A class to load dataset for tagging (e.g., NER and POS) task."""
"""A class to load dataset for tagging (e.g., NER and POS) task."""
def
__init__
(
self
,
params
):
def
__init__
(
self
,
params
:
TaggingDataConfig
):
self
.
_params
=
params
self
.
_params
=
params
self
.
_seq_length
=
params
.
seq_length
self
.
_seq_length
=
params
.
seq_length
...
...
official/nlp/tasks/masked_lm.py
View file @
c4451b7a
...
@@ -20,7 +20,7 @@ import tensorflow as tf
...
@@ -20,7 +20,7 @@ import tensorflow as tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.data
import
pretrain_
dataloader
from
official.nlp.data
import
data
_
loader
_factory
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -95,8 +95,7 @@ class MaskedLMTask(base_task.Task):
...
@@ -95,8 +95,7 @@ class MaskedLMTask(base_task.Task):
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
dataset
return
pretrain_dataloader
.
BertPretrainDataLoader
(
params
).
load
(
return
data_loader_factory
.
get_data_loader
(
params
).
load
(
input_context
)
input_context
)
def
build_metrics
(
self
,
training
=
None
):
def
build_metrics
(
self
,
training
=
None
):
del
training
del
training
...
...
official/nlp/tasks/masked_lm_test.py
View file @
c4451b7a
...
@@ -19,6 +19,7 @@ import tensorflow as tf
...
@@ -19,6 +19,7 @@ import tensorflow as tf
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.tasks
import
masked_lm
from
official.nlp.tasks
import
masked_lm
...
@@ -33,7 +34,7 @@ class MLMTaskTest(tf.test.TestCase):
...
@@ -33,7 +34,7 @@ class MLMTaskTest(tf.test.TestCase):
bert
.
ClsHeadConfig
(
bert
.
ClsHeadConfig
(
inner_dim
=
10
,
num_classes
=
2
,
name
=
"next_sentence"
)
inner_dim
=
10
,
num_classes
=
2
,
name
=
"next_sentence"
)
]),
]),
train_data
=
bert
.
BertPretrainDataConfig
(
train_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
input_path
=
"dummy"
,
input_path
=
"dummy"
,
max_predictions_per_seq
=
20
,
max_predictions_per_seq
=
20
,
seq_length
=
128
,
seq_length
=
128
,
...
...
official/nlp/tasks/sentence_prediction.py
View file @
c4451b7a
...
@@ -25,7 +25,7 @@ import tensorflow_hub as hub
...
@@ -25,7 +25,7 @@ import tensorflow_hub as hub
from
official.core
import
base_task
from
official.core
import
base_task
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.data
import
sentence_prediction_
dataloader
from
official.nlp.data
import
data
_
loader
_factory
from
official.nlp.tasks
import
utils
from
official.nlp.tasks
import
utils
...
@@ -103,8 +103,7 @@ class SentencePredictionTask(base_task.Task):
...
@@ -103,8 +103,7 @@ class SentencePredictionTask(base_task.Task):
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
dataset
return
sentence_prediction_dataloader
.
SentencePredictionDataLoader
(
return
data_loader_factory
.
get_data_loader
(
params
).
load
(
input_context
)
params
).
load
(
input_context
)
def
build_metrics
(
self
,
training
=
None
):
def
build_metrics
(
self
,
training
=
None
):
del
training
del
training
...
...
official/nlp/tasks/sentence_prediction_test.py
View file @
c4451b7a
...
@@ -24,6 +24,7 @@ from official.nlp.bert import configs
...
@@ -24,6 +24,7 @@ from official.nlp.bert import configs
from
official.nlp.bert
import
export_tfhub
from
official.nlp.bert
import
export_tfhub
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
sentence_prediction_dataloader
from
official.nlp.tasks
import
sentence_prediction
from
official.nlp.tasks
import
sentence_prediction
...
@@ -31,8 +32,9 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -31,8 +32,9 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
SentencePredictionTaskTest
,
self
).
setUp
()
super
(
SentencePredictionTaskTest
,
self
).
setUp
()
self
.
_train_data_config
=
bert
.
SentencePredictionDataConfig
(
self
.
_train_data_config
=
(
input_path
=
"dummy"
,
seq_length
=
128
,
global_batch_size
=
1
)
sentence_prediction_dataloader
.
SentencePredictionDataConfig
(
input_path
=
"dummy"
,
seq_length
=
128
,
global_batch_size
=
1
))
def
get_model_config
(
self
,
num_classes
):
def
get_model_config
(
self
,
num_classes
):
return
bert
.
BertPretrainerConfig
(
return
bert
.
BertPretrainerConfig
(
...
...
official/nlp/tasks/tagging.py
View file @
c4451b7a
...
@@ -27,7 +27,7 @@ import tensorflow_hub as hub
...
@@ -27,7 +27,7 @@ import tensorflow_hub as hub
from
official.core
import
base_task
from
official.core
import
base_task
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
tagging_
data_loader
from
official.nlp.data
import
data_loader
_factory
from
official.nlp.modeling
import
models
from
official.nlp.modeling
import
models
from
official.nlp.tasks
import
utils
from
official.nlp.tasks
import
utils
...
@@ -138,8 +138,7 @@ class TaggingTask(base_task.Task):
...
@@ -138,8 +138,7 @@ class TaggingTask(base_task.Task):
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
dataset
dataset
=
tagging_data_loader
.
TaggingDataLoader
(
params
).
load
(
input_context
)
return
data_loader_factory
.
get_data_loader
(
params
).
load
(
input_context
)
return
dataset
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
"""Validatation step.
"""Validatation step.
...
...
official/nlp/tasks/tagging_test.py
View file @
c4451b7a
...
@@ -20,8 +20,8 @@ import tensorflow as tf
...
@@ -20,8 +20,8 @@ import tensorflow as tf
from
official.nlp.bert
import
configs
from
official.nlp.bert
import
configs
from
official.nlp.bert
import
export_tfhub
from
official.nlp.bert
import
export_tfhub
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
tagging_data_loader
from
official.nlp.tasks
import
tagging
from
official.nlp.tasks
import
tagging
...
@@ -31,7 +31,7 @@ class TaggingTest(tf.test.TestCase):
...
@@ -31,7 +31,7 @@ class TaggingTest(tf.test.TestCase):
super
(
TaggingTest
,
self
).
setUp
()
super
(
TaggingTest
,
self
).
setUp
()
self
.
_encoder_config
=
encoders
.
TransformerEncoderConfig
(
self
.
_encoder_config
=
encoders
.
TransformerEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)
vocab_size
=
30522
,
num_layers
=
1
)
self
.
_train_data_config
=
b
er
t
.
TaggingDataConfig
(
self
.
_train_data_config
=
tagging_data_load
er
.
TaggingDataConfig
(
input_path
=
"dummy"
,
seq_length
=
128
,
global_batch_size
=
1
)
input_path
=
"dummy"
,
seq_length
=
128
,
global_batch_size
=
1
)
def
_run_task
(
self
,
config
):
def
_run_task
(
self
,
config
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment