ModelZoo / ResNet50_tensorflow / Commits

Commit 1c89b792
Authored Aug 21, 2020 by Maxim Neumann; committed by A. Unique TensorFlower, Aug 21, 2020

Add a flag to control the number of train examples.

PiperOrigin-RevId: 327838493
parent e0b6ce02

Showing 2 changed files with 19 additions and 6 deletions:
official/nlp/bert/input_pipeline.py (+7 -3)
official/nlp/bert/run_classifier.py (+12 -3)
official/nlp/bert/input_pipeline.py (view file @ 1c89b792)
@@ -36,11 +36,13 @@ def decode_record(record, name_to_features):
   return example
 
 
-def single_file_dataset(input_file, name_to_features):
+def single_file_dataset(input_file, name_to_features, num_samples=None):
   """Creates a single-file dataset to be passed for BERT custom training."""
   # For training, we want a lot of parallel reading and shuffling.
   # For eval, we want no shuffling and parallel reading doesn't matter.
   d = tf.data.TFRecordDataset(input_file)
+  if num_samples:
+    d = d.take(num_samples)
   d = d.map(
       lambda record: decode_record(record, name_to_features),
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
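
The core of the change is the new num_samples argument wired into tf.data.Dataset.take. A minimal, self-contained sketch of the semantics the if num_samples: branch relies on (not from the commit; a range dataset stands in for the TFRecord file):

import tensorflow as tf

ds = tf.data.Dataset.range(10)        # stands in for TFRecordDataset(input_file)
ds = ds.take(4)                       # what d = d.take(num_samples) does
print(list(ds.as_numpy_iterator()))   # [0, 1, 2, 3]

# take(n) yields at most n elements, so a num_samples larger than the
# file is harmless: range(10).take(100) still yields exactly 10 elements.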
@@ -156,7 +158,8 @@ def create_classifier_dataset(file_path,
                               is_training=True,
                               input_pipeline_context=None,
                               label_type=tf.int64,
-                              include_sample_weights=False):
+                              include_sample_weights=False,
+                              num_samples=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -166,7 +169,8 @@ def create_classifier_dataset(file_path,
   }
   if include_sample_weights:
     name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)
-  dataset = single_file_dataset(file_path, name_to_features)
+  dataset = single_file_dataset(file_path, name_to_features,
+                                num_samples=num_samples)
 
   # The dataset is always sharded by number of hosts.
   # num_input_pipelines is the number of hosts rather than number of cores.
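
Because the take happens inside single_file_dataset, truncation is applied before the host sharding that the comment above describes, so every host draws its shard from the same num_samples records. A minimal sketch of that ordering (shard parameters assumed for illustration, not from the commit):

import tensorflow as tf

ds = tf.data.Dataset.range(100)
ds = ds.take(10)                         # num_samples=10 applied first
host0 = ds.shard(num_shards=2, index=0)  # mirrors the pipeline-context sharding
print(list(host0.as_numpy_iterator()))   # [0, 2, 4, 6, 8]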
official/nlp/bert/run_classifier.py (view file @ 1c89b792)
@@ -53,6 +53,9 @@ flags.DEFINE_string(
     'input_meta_data_path', None,
     'Path to file that contains meta data about input '
     'to be used for training and evaluation.')
+flags.DEFINE_integer('train_data_size', None, 'Number of training samples '
+                     'to use. If None, uses the full train data. '
+                     '(default: None).')
 flags.DEFINE_string('predict_checkpoint_path', None,
                     'Path to the checkpoint for predictions.')
 flags.DEFINE_integer(
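
For readers unfamiliar with absl flags, here is a self-contained sketch of how an integer flag with a None default behaves; apart from the train_data_size name, everything in it is illustrative:

from absl import app, flags

flags.DEFINE_integer('train_data_size', None,
                     'Number of training samples to use. If None, uses the '
                     'full train data.')
FLAGS = flags.FLAGS

def main(_):
  # None (flag unset) is falsy, so the pipeline falls through to the full
  # training set; run with e.g. --train_data_size=1000 to truncate.
  if FLAGS.train_data_size:
    print('truncating to', FLAGS.train_data_size, 'examples')
  else:
    print('using the full training set')

if __name__ == '__main__':
  app.run(main)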
@@ -92,7 +95,8 @@ def get_dataset_fn(input_file_pattern,
                    global_batch_size,
                    is_training,
                    label_type=tf.int64,
-                   include_sample_weights=False):
+                   include_sample_weights=False,
+                   num_samples=None):
   """Gets a closure to create a dataset."""
 
   def _dataset_fn(ctx=None):
@@ -106,7 +110,8 @@ def get_dataset_fn(input_file_pattern,
         is_training=is_training,
         input_pipeline_context=ctx,
         label_type=label_type,
-        include_sample_weights=include_sample_weights)
+        include_sample_weights=include_sample_weights,
+        num_samples=num_samples)
     return dataset
 
   return _dataset_fn
@@ -374,6 +379,9 @@ def run_bert(strategy,
   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
       input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
+  if FLAGS.train_data_size:
+    train_data_size = min(train_data_size, FLAGS.train_data_size)
+    logging.info('Updated train_data_size: %s', train_data_size)
   steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
   warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
   eval_steps = int(
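
To make the new cap concrete, here is the step arithmetic above with assumed values (3 epochs, one eval per epoch, batch size 32, 25000 examples recorded in input_meta_data, and --train_data_size=1000); none of the numbers come from the commit itself:

num_train_epochs, num_eval_per_epoch, train_batch_size = 3, 1, 32
meta_train_data_size = 25000   # input_meta_data['train_data_size'], assumed
flag_train_data_size = 1000    # --train_data_size=1000, assumed

epochs = num_train_epochs * num_eval_per_epoch                 # 3
train_data_size = meta_train_data_size // num_eval_per_epoch   # 25000
train_data_size = min(train_data_size, flag_train_data_size)   # capped at 1000
steps_per_epoch = int(train_data_size / train_batch_size)      # 31
warmup_steps = int(epochs * train_data_size * 0.1 / train_batch_size)  # 9
print(steps_per_epoch, warmup_steps)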
@@ -489,7 +497,8 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
       FLAGS.train_batch_size,
       is_training=True,
       label_type=label_type,
-      include_sample_weights=include_sample_weights)
+      include_sample_weights=include_sample_weights,
+      num_samples=FLAGS.train_data_size)
   run_bert(strategy, input_meta_data,
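
Note that only this training closure receives num_samples=FLAGS.train_data_size; as far as the diff shows, evaluation data is left untruncated. A hedged miniature of that asymmetry, where build_dataset is an illustrative stand-in for create_classifier_dataset:

import tensorflow as tf

def build_dataset(num_samples=None):
  # Illustrative stand-in for create_classifier_dataset; a range dataset
  # replaces real TFRecord examples.
  ds = tf.data.Dataset.range(100)
  if num_samples:
    ds = ds.take(num_samples)
  return ds

train_ds = build_dataset(num_samples=1000)  # capped by --train_data_size
eval_ds = build_dataset()                   # eval path: never truncated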