Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
88c864f7
Commit
88c864f7
authored
Oct 23, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Oct 23, 2019
Browse files
Remove unnecessary test_input_fn.
PiperOrigin-RevId: 276394582
parent
a4789d12
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
11 deletions
+5
-11
official/nlp/xlnet/run_classifier.py
official/nlp/xlnet/run_classifier.py
+0
-1
official/nlp/xlnet/run_pretrain.py
official/nlp/xlnet/run_pretrain.py
+0
-1
official/nlp/xlnet/run_squad.py
official/nlp/xlnet/run_squad.py
+0
-1
official/nlp/xlnet/training_utils.py
official/nlp/xlnet/training_utils.py
+5
-8
No files found.
official/nlp/xlnet/run_classifier.py
View file @
88c864f7
...
...
@@ -184,7 +184,6 @@ def main(unused_argv):
eval_fn
=
eval_fn
,
metric_fn
=
get_metric_fn
,
train_input_fn
=
train_input_fn
,
test_input_fn
=
test_input_fn
,
init_checkpoint
=
FLAGS
.
init_checkpoint
,
init_from_transformerxl
=
FLAGS
.
init_from_transformerxl
,
total_training_steps
=
total_training_steps
,
...
...
official/nlp/xlnet/run_pretrain.py
View file @
88c864f7
...
...
@@ -135,7 +135,6 @@ def main(unused_argv):
eval_fn
=
None
,
metric_fn
=
None
,
train_input_fn
=
train_input_fn
,
test_input_fn
=
None
,
init_checkpoint
=
FLAGS
.
init_checkpoint
,
init_from_transformerxl
=
FLAGS
.
init_from_transformerxl
,
total_training_steps
=
total_training_steps
,
...
...
official/nlp/xlnet/run_squad.py
View file @
88c864f7
...
...
@@ -281,7 +281,6 @@ def main(unused_argv):
eval_fn
=
eval_fn
,
metric_fn
=
None
,
train_input_fn
=
train_input_fn
,
test_input_fn
=
test_input_fn
,
init_checkpoint
=
FLAGS
.
init_checkpoint
,
init_from_transformerxl
=
FLAGS
.
init_from_transformerxl
,
total_training_steps
=
total_training_steps
,
...
...
official/nlp/xlnet/training_utils.py
View file @
88c864f7
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet
classification finetuning runner in tf2.0
."""
"""XLNet
training utils
."""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -61,7 +61,6 @@ def train(
eval_fn
:
Optional
[
Callable
[[
tf
.
keras
.
Model
,
int
,
tf
.
summary
.
SummaryWriter
],
Any
]]
=
None
,
metric_fn
:
Optional
[
Callable
[[],
tf
.
keras
.
metrics
.
Metric
]]
=
None
,
test_input_fn
:
Optional
[
Callable
]
=
None
,
init_checkpoint
:
Optional
[
Text
]
=
None
,
init_from_transformerxl
:
Optional
[
bool
]
=
False
,
model_dir
:
Optional
[
Text
]
=
None
,
...
...
@@ -86,8 +85,6 @@ def train(
metric_fn: A metrics function returns a Keras Metric object to record
evaluation result using evaluation dataset or with training dataset
after every epoch.
test_input_fn: Function returns a evaluation dataset. If none, evaluation
is skipped.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
init_from_transformerxl: Whether to load to `transformerxl_model` of
...
...
@@ -124,7 +121,7 @@ def train(
tf
.
io
.
gfile
.
mkdir
(
summary_dir
)
train_summary_writer
=
None
eval_summary_writer
=
None
if
test_input
_fn
:
if
eval
_fn
:
eval_summary_writer
=
tf
.
summary
.
create_file_writer
(
os
.
path
.
join
(
summary_dir
,
"eval"
))
if
steps_per_loop
>=
_MIN_SUMMARY_STEPS
:
...
...
@@ -288,7 +285,7 @@ def train(
_save_checkpoint
(
checkpoint
,
model_dir
,
checkpoint_name
.
format
(
step
=
current_step
))
if
test_input
_fn
and
current_step
%
save_steps
==
0
:
if
eval
_fn
and
current_step
%
save_steps
==
0
:
logging
.
info
(
"Running evaluation after step: %s."
,
current_step
)
...
...
@@ -296,7 +293,7 @@ def train(
if
model_dir
:
_save_checkpoint
(
checkpoint
,
model_dir
,
checkpoint_name
.
format
(
step
=
current_step
))
if
test_input
_fn
:
if
eval
_fn
:
logging
.
info
(
"Running final evaluation after training is complete."
)
eval_metric
=
eval_fn
(
model
,
current_step
,
eval_summary_writer
)
...
...
@@ -306,7 +303,7 @@ def train(
}
if
train_metric
:
training_summary
[
"last_train_metrics"
]
=
_float_metric_value
(
train_metric
)
if
test_input
_fn
:
if
eval
_fn
:
# eval_metric is supposed to be a float.
training_summary
[
"eval_metrics"
]
=
eval_metric
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment