ModelZoo / ResNet50_tensorflow / Commits / 32e4ca51

Commit 32e4ca51, authored Nov 28, 2023 by qianyj

Update code to v2.11.0

Parents: 9485aa1d, 71060f67
Changes: 775
Showing 20 changed files with 462 additions and 20 deletions (+462 / -20).
official/legacy/transformer/transformer_layers_test.py      +1  -1
official/legacy/transformer/transformer_main.py             +1  -1
official/legacy/transformer/transformer_main_test.py        +1  -1
official/legacy/transformer/transformer_test.py             +1  -1
official/legacy/transformer/translate.py                    +1  -1
official/legacy/transformer/utils/__init__.py               +1  -1
official/legacy/transformer/utils/metrics.py                +1  -1
official/legacy/transformer/utils/tokenizer.py              +1  -1
official/legacy/transformer/utils/tokenizer_test.py         +1  -1
official/legacy/xlnet/README.md                             +0  -0
official/legacy/xlnet/__init__.py                           +15 -0
official/legacy/xlnet/classifier_utils.py                   +2  -2
official/legacy/xlnet/common_flags.py                       +142 -0
official/legacy/xlnet/data_utils.py                         +1  -1
official/legacy/xlnet/optimization.py                       +98 -0
official/legacy/xlnet/preprocess_classification_data.py     +3  -3
official/legacy/xlnet/preprocess_pretrain_data.py           +2  -2
official/legacy/xlnet/preprocess_squad_data.py              +2  -2
official/legacy/xlnet/preprocess_utils.py                   +1  -1
official/legacy/xlnet/run_classifier.py                     +187 -0
Too many changes to show: to preserve performance, only 775 of 775+ files are displayed.
official/legacy/transformer/transformer_layers_test.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/transformer_main.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/transformer_main_test.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/transformer_test.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/translate.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/utils/__init__.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/utils/metrics.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/utils/tokenizer.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...

official/legacy/transformer/utils/tokenizer_test.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/nlp/xlnet/README.md → official/legacy/xlnet/README.md  View file @ 32e4ca51

File moved
official/legacy/xlnet/__init__.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/xlnet/classifier_utils.py → official/legacy/xlnet/classifier_utils.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -16,7 +16,7 @@
 from absl import logging
-from official.nlp.xlnet import data_utils
+from official.legacy.xlnet import data_utils
 SEG_ID_A = 0
 SEG_ID_B = 1
 ...
official/legacy/xlnet/common_flags.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common flags used in XLNet model."""
from absl import flags

flags.DEFINE_string("master", default=None, help="master")
flags.DEFINE_string(
    "tpu",
    default=None,
    help="The Cloud TPU to use for training. This should be "
    "either the name used when creating the Cloud TPU, or a "
    "url like grpc://ip.address.of.tpu:8470.")
flags.DEFINE_bool(
    "use_tpu", default=True, help="Use TPUs rather than plain CPUs.")
flags.DEFINE_string("tpu_topology", "2x2", help="TPU topology.")
flags.DEFINE_integer(
    "num_core_per_host", default=8, help="number of cores per host")
flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
flags.DEFINE_string(
    "init_checkpoint",
    default=None,
    help="Checkpoint path for initializing the model.")
flags.DEFINE_bool(
    "init_from_transformerxl",
    default=False,
    help="Init from a transformerxl model checkpoint. Otherwise, init from the "
    "entire model checkpoint.")

# Optimization config
flags.DEFINE_float("learning_rate", default=1e-4, help="Maximum learning rate.")
flags.DEFINE_float("clip", default=1.0, help="Gradient clipping value.")
flags.DEFINE_float("weight_decay_rate", default=0.0, help="Weight decay rate.")

# lr decay
flags.DEFINE_integer(
    "warmup_steps", default=0, help="Number of steps for linear lr warmup.")
flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon.")
flags.DEFINE_float(
    "lr_layer_decay_rate",
    default=1.0,
    help="Top layer: lr[L] = FLAGS.learning_rate."
    "Lower layers: lr[l-1] = lr[l] * lr_layer_decay_rate.")
flags.DEFINE_float(
    "min_lr_ratio", default=0.0, help="Minimum ratio learning rate.")

# Training config
flags.DEFINE_integer(
    "train_batch_size",
    default=16,
    help="Size of the train batch across all hosts.")
flags.DEFINE_integer(
    "train_steps", default=100000, help="Total number of training steps.")
flags.DEFINE_integer(
    "iterations", default=1000, help="Number of iterations per repeat loop.")

# Data config
flags.DEFINE_integer(
    "seq_len", default=0, help="Sequence length for pretraining.")
flags.DEFINE_integer(
    "reuse_len",
    default=0,
    help="How many tokens to be reused in the next batch. "
    "Could be half of `seq_len`.")
flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
flags.DEFINE_bool(
    "bi_data",
    default=False,
    help="Use bidirectional data streams, i.e., forward & backward.")
flags.DEFINE_integer("n_token", 32000, help="Vocab size")

# Model config
flags.DEFINE_integer("mem_len", default=0, help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False, help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")
flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
flags.DEFINE_integer("d_model", default=32, help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=32, help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=4, help="Number of attention heads.")
flags.DEFINE_integer(
    "d_head", default=8, help="Dimension of each attention head.")
flags.DEFINE_integer(
    "d_inner",
    default=32,
    help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
flags.DEFINE_float("dropout_att", default=0.1, help="Attention dropout rate.")
flags.DEFINE_bool(
    "untie_r", default=False, help="Untie r_w_bias and r_r_bias")
flags.DEFINE_string(
    "ff_activation",
    default="relu",
    help="Activation type used in position-wise feed-forward.")
flags.DEFINE_string(
    "strategy_type",
    default="tpu",
    help="Activation type used in position-wise feed-forward.")
flags.DEFINE_bool("use_bfloat16", False, help="Whether to use bfloat16.")

# Parameter initialization
flags.DEFINE_enum(
    "init_method",
    default="normal",
    enum_values=["normal", "uniform"],
    help="Initialization method.")
flags.DEFINE_float(
    "init_std", default=0.02, help="Initialization std when init is normal.")
flags.DEFINE_float(
    "init_range", default=0.1, help="Initialization std when init is uniform.")
flags.DEFINE_integer(
    "test_data_size", default=12048, help="Number of test data samples.")
flags.DEFINE_string(
    "train_tfrecord_path",
    default=None,
    help="Path to preprocessed training set tfrecord.")
flags.DEFINE_string(
    "test_tfrecord_path",
    default=None,
    help="Path to preprocessed test set tfrecord.")
flags.DEFINE_integer(
    "test_batch_size",
    default=16,
    help="Size of the test batch across all hosts.")
flags.DEFINE_integer(
    "save_steps", default=1000, help="Number of steps for saving checkpoint.")

FLAGS = flags.FLAGS
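For orientation, a minimal sketch of how a binary typically consumes these shared flag definitions: importing the module registers every DEFINE_* call with absl's global registry, and absl.app parses the command line before invoking main. The driver below is hypothetical and not part of this commit.

# Hypothetical driver (not from this commit): importing common_flags
# registers its flags globally; app.run() parses argv before main() runs.
from absl import app
from absl import flags

from official.legacy.xlnet import common_flags  # pylint: disable=unused-import

FLAGS = flags.FLAGS


def main(unused_argv):
  del unused_argv
  # After parsing, every flag defined in common_flags.py is an
  # attribute on FLAGS.
  print("learning_rate:", FLAGS.learning_rate)  # defaults to 1e-4
  print("train_steps:", FLAGS.train_steps)      # defaults to 100000


if __name__ == "__main__":
  app.run(main)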
official/nlp/xlnet/data_utils.py → official/legacy/xlnet/data_utils.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/legacy/xlnet/optimization.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and classes related to optimization (weight updates)."""
from absl import logging
import tensorflow as tf

from official.nlp import optimization


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or "WarmUp") as name:
      # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
      # learning rate will be `global_step/num_warmup_steps * init_lr`.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step - self.warmup_steps),
          name=name)

  def get_config(self):
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_schedule_fn": self.decay_schedule_fn,
        "warmup_steps": self.warmup_steps,
        "power": self.power,
        "name": self.name
    }


def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     min_lr_ratio=0.0,
                     adam_epsilon=1e-8,
                     weight_decay_rate=0.0):
  """Creates an optimizer with learning rate schedule."""
  # Implements linear decay of the learning rate.
  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps - num_warmup_steps,
      end_learning_rate=init_lr * min_lr_ratio)
  if num_warmup_steps:
    learning_rate_fn = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=learning_rate_fn,
        warmup_steps=num_warmup_steps)
  if weight_decay_rate > 0.0:
    logging.info(
        "Using AdamWeightDecay with adam_epsilon=%.9f weight_decay_rate=%.3f",
        adam_epsilon, weight_decay_rate)
    optimizer = optimization.AdamWeightDecay(
        learning_rate=learning_rate_fn,
        weight_decay_rate=weight_decay_rate,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=adam_epsilon,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
        include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
  else:
    logging.info("Using Adam with adam_epsilon=%.9f", (adam_epsilon))
    optimizer = tf.keras.optimizers.legacy.Adam(
        learning_rate=learning_rate_fn, epsilon=adam_epsilon)

  return optimizer, learning_rate_fn
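As a sanity check on what create_optimizer() builds, a small sketch follows: the composed schedule warms up linearly to init_lr over warmup_steps, then follows the polynomial (by default linear) decay toward init_lr * min_lr_ratio. This is illustrative only and not part of this commit; the step counts and learning rate below are made up.

# Illustrative only (not from this commit): evaluate the composed
# schedule at a few steps in eager mode.
init_lr, warmup, total = 1e-4, 100, 1000
decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=total - warmup,
    end_learning_rate=0.0)  # corresponds to min_lr_ratio = 0.0
schedule = WarmUp(
    initial_learning_rate=init_lr,
    decay_schedule_fn=decay,
    warmup_steps=warmup)

for step in (1, 50, 100, 550, 1000):
  # Before warmup ends: lr = init_lr * step / warmup.
  # Afterwards: lr = decay(step - warmup), reaching 0.0 at `total`.
  print(step, float(schedule(step)))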
official/nlp/xlnet/preprocess_classification_data.py → official/legacy/xlnet/preprocess_classification_data.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -26,8 +26,8 @@ import numpy as np
 import tensorflow as tf
 import sentencepiece as spm
-from official.nlp.xlnet import classifier_utils
-from official.nlp.xlnet import preprocess_utils
+from official.legacy.xlnet import classifier_utils
+from official.legacy.xlnet import preprocess_utils
 flags.DEFINE_bool(
 ...
official/nlp/xlnet/preprocess_pretrain_data.py → official/legacy/xlnet/preprocess_pretrain_data.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -28,7 +28,7 @@ import numpy as np
 import tensorflow.compat.v1 as tf
 import sentencepiece as spm
-from official.nlp.xlnet import preprocess_utils
+from official.legacy.xlnet import preprocess_utils
 FLAGS = flags.FLAGS
 ...
official/nlp/xlnet/preprocess_squad_data.py → official/legacy/xlnet/preprocess_squad_data.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -25,7 +25,7 @@ from absl import logging
 import tensorflow as tf
 import sentencepiece as spm
-from official.nlp.xlnet import squad_utils
+from official.legacy.xlnet import squad_utils
 flags.DEFINE_integer(
     "num_proc", default=1, help="Number of preprocessing processes.")
 ...
official/nlp/xlnet/preprocess_utils.py → official/legacy/xlnet/preprocess_utils.py  View file @ 32e4ca51

-# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
official/legacy/xlnet/run_classifier.py  0 → 100644  View file @ 32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet classification finetuning runner in tf2.0."""
import functools

# Import libraries
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import
from official.common import distribute_utils
from official.legacy.xlnet import common_flags
from official.legacy.xlnet import data_utils
from official.legacy.xlnet import optimization
from official.legacy.xlnet import training_utils
from official.legacy.xlnet import xlnet_config
from official.legacy.xlnet import xlnet_modeling as modeling

flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string(
    "summary_type",
    default="last",
    help="Method used to summarize a sequence into a vector.")

FLAGS = flags.FLAGS


def get_classificationxlnet_model(model_config,
                                  run_config,
                                  n_class,
                                  summary_type="last"):
  model = modeling.ClassificationXLNetModel(
      model_config, run_config, n_class, summary_type, name="model")
  return model


def run_evaluation(strategy,
                   test_input_fn,
                   eval_steps,
                   model,
                   step,
                   eval_summary_writer=None):
  """Run evaluation for classification task.

  Args:
    strategy: distribution strategy.
    test_input_fn: input function for evaluation data.
    eval_steps: total number of evaluation steps.
    model: keras model object.
    step: current train step.
    eval_summary_writer: summary writer used to record evaluation metrics. As
      there are fake data samples in validation set, we use mask to get rid of
      them when calculating the accuracy. For the reason that there will be
      dynamic-shape tensor, we first collect logits, labels and masks from TPU
      and calculate the accuracy via numpy locally.

  Returns:
    A float metric, accuracy.
  """

  def _test_step_fn(inputs):
    """Replicated validation step."""
    inputs["mems"] = None
    _, logits = model(inputs, training=False)
    return logits, inputs["label_ids"], inputs["is_real_example"]

  @tf.function
  def _run_evaluation(test_iterator):
    """Runs validation steps."""
    logits, labels, masks = strategy.run(
        _test_step_fn, args=(next(test_iterator),))
    return logits, labels, masks

  test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
  correct = 0
  total = 0
  for _ in range(eval_steps):
    logits, labels, masks = _run_evaluation(test_iterator)
    logits = strategy.experimental_local_results(logits)
    labels = strategy.experimental_local_results(labels)
    masks = strategy.experimental_local_results(masks)
    merged_logits = []
    merged_labels = []
    merged_masks = []

    for i in range(strategy.num_replicas_in_sync):
      merged_logits.append(logits[i].numpy())
      merged_labels.append(labels[i].numpy())
      merged_masks.append(masks[i].numpy())
    merged_logits = np.vstack(np.array(merged_logits))
    merged_labels = np.hstack(np.array(merged_labels))
    merged_masks = np.hstack(np.array(merged_masks))
    real_index = np.where(np.equal(merged_masks, 1))
    correct += np.sum(
        np.equal(
            np.argmax(merged_logits[real_index], axis=-1),
            merged_labels[real_index]))
    total += np.shape(real_index)[-1]
  accuracy = float(correct) / float(total)
  logging.info("Train step: %d / acc = %d/%d = %f", step, correct, total,
               accuracy)
  if eval_summary_writer:
    with eval_summary_writer.as_default():
      tf.summary.scalar("eval_acc", float(correct) / float(total), step=step)
      eval_summary_writer.flush()
  return accuracy


def get_metric_fn():
  train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
      "acc", dtype=tf.float32)
  return train_acc_metric


def main(unused_argv):
  del unused_argv
  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type, tpu_address=FLAGS.tpu)
  if strategy:
    logging.info("***** Number of cores used : %d",
                 strategy.num_replicas_in_sync)
  train_input_fn = functools.partial(data_utils.get_classification_input_data,
                                     FLAGS.train_batch_size, FLAGS.seq_len,
                                     strategy, True, FLAGS.train_tfrecord_path)
  test_input_fn = functools.partial(data_utils.get_classification_input_data,
                                    FLAGS.test_batch_size, FLAGS.seq_len,
                                    strategy, False, FLAGS.test_tfrecord_path)

  total_training_steps = FLAGS.train_steps
  steps_per_loop = FLAGS.iterations
  eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
  eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
                              eval_steps)
  optimizer, learning_rate_fn = optimization.create_optimizer(
      FLAGS.learning_rate,
      total_training_steps,
      FLAGS.warmup_steps,
      adam_epsilon=FLAGS.adam_epsilon)
  model_config = xlnet_config.XLNetConfig(FLAGS)
  run_config = xlnet_config.create_run_config(True, False, FLAGS)
  model_fn = functools.partial(get_classificationxlnet_model, model_config,
                               run_config, FLAGS.n_class, FLAGS.summary_type)
  input_meta_data = {}
  input_meta_data["d_model"] = FLAGS.d_model
  input_meta_data["mem_len"] = FLAGS.mem_len
  input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
                                               strategy.num_replicas_in_sync)
  input_meta_data["n_layer"] = FLAGS.n_layer
  input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
  input_meta_data["n_class"] = FLAGS.n_class

  training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=get_metric_fn,
      train_input_fn=train_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      init_from_transformerxl=FLAGS.init_from_transformerxl,
      total_training_steps=total_training_steps,
      steps_per_loop=steps_per_loop,
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir,
      save_steps=FLAGS.save_steps)


if __name__ == "__main__":
  app.run(main)
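To make the masking in run_evaluation() concrete, here is a toy numpy walkthrough. It is not from this commit and the values are made up: padded "fake" records carry is_real_example == 0 and are excluded before the accuracy is computed.

# Toy illustration (not from this commit): drop padded examples via the
# mask, then score only the real ones, mirroring run_evaluation() above.
import numpy as np

merged_logits = np.array([[0.1, 0.9],    # real record, predicts class 1
                          [0.8, 0.2],    # real record, predicts class 0
                          [0.5, 0.5]])   # fake padding record
merged_labels = np.array([1, 1, 0])
merged_masks = np.array([1, 1, 0])       # is_real_example per record

real_index = np.where(np.equal(merged_masks, 1))
correct = np.sum(np.equal(
    np.argmax(merged_logits[real_index], axis=-1), merged_labels[real_index]))
total = np.shape(real_index)[-1]
print(float(correct) / float(total))  # 0.5: one of the two real records is right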