Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f9ac9618
Commit
f9ac9618
authored
Jun 22, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jun 22, 2020
Browse files
Remove this r1 folder from the master branch in June, 2020.
PiperOrigin-RevId: 317772122
parent
d4f5c193
Changes
70
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
2334 deletions
+0
-2334
official/r1/utils/__init__.py
official/r1/utils/__init__.py
+0
-0
official/r1/utils/data/__init__.py
official/r1/utils/data/__init__.py
+0
-0
official/r1/utils/data/file_io.py
official/r1/utils/data/file_io.py
+0
-207
official/r1/utils/data/file_io_test.py
official/r1/utils/data/file_io_test.py
+0
-197
official/r1/utils/export.py
official/r1/utils/export.py
+0
-49
official/r1/utils/export_test.py
official/r1/utils/export_test.py
+0
-63
official/r1/utils/logs/__init__.py
official/r1/utils/logs/__init__.py
+0
-0
official/r1/utils/logs/cloud_lib.py
official/r1/utils/logs/cloud_lib.py
+0
-34
official/r1/utils/logs/cloud_lib_test.py
official/r1/utils/logs/cloud_lib_test.py
+0
-48
official/r1/utils/logs/guidelines.md
official/r1/utils/logs/guidelines.md
+0
-58
official/r1/utils/logs/hooks.py
official/r1/utils/logs/hooks.py
+0
-130
official/r1/utils/logs/hooks_helper.py
official/r1/utils/logs/hooks_helper.py
+0
-173
official/r1/utils/logs/hooks_test.py
official/r1/utils/logs/hooks_test.py
+0
-159
official/r1/utils/logs/logger.py
official/r1/utils/logs/logger.py
+0
-305
official/r1/utils/logs/logger_test.py
official/r1/utils/logs/logger_test.py
+0
-253
official/r1/utils/logs/metric_hook.py
official/r1/utils/logs/metric_hook.py
+0
-97
official/r1/utils/logs/metric_hook_test.py
official/r1/utils/logs/metric_hook_test.py
+0
-217
official/r1/utils/logs/mlperf_helper.py
official/r1/utils/logs/mlperf_helper.py
+0
-192
official/r1/utils/logs/mock_lib.py
official/r1/utils/logs/mock_lib.py
+0
-36
official/r1/utils/tpu.py
official/r1/utils/tpu.py
+0
-116
No files found.
official/r1/utils/__init__.py
deleted
100644 → 0
View file @
d4f5c193
official/r1/utils/data/__init__.py
deleted
100644 → 0
View file @
d4f5c193
official/r1/utils/data/file_io.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for managing dataset file buffers."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
atexit
import
multiprocessing
import
multiprocessing.dummy
import
os
import
tempfile
import
uuid
from
absl
import
logging
import
numpy
as
np
import
six
import
tensorflow
as
tf
# pylint:disable=logging-format-interpolation
class
_GarbageCollector
(
object
):
"""Deletes temporary buffer files at exit.
Certain tasks (such as NCF Recommendation) require writing buffers to
temporary files. (Which may be local or distributed.) It is not generally safe
to delete these files during operation, but they should be cleaned up. This
class keeps track of temporary files created, and deletes them at exit.
"""
def
__init__
(
self
):
self
.
temp_buffers
=
[]
def
register
(
self
,
filepath
):
self
.
temp_buffers
.
append
(
filepath
)
def
purge
(
self
):
try
:
for
i
in
self
.
temp_buffers
:
if
tf
.
io
.
gfile
.
exists
(
i
):
tf
.
io
.
gfile
.
remove
(
i
)
logging
.
info
(
"Buffer file {} removed"
.
format
(
i
))
except
Exception
as
e
:
logging
.
error
(
"Failed to cleanup buffer files: {}"
.
format
(
e
))
_GARBAGE_COLLECTOR
=
_GarbageCollector
()
atexit
.
register
(
_GARBAGE_COLLECTOR
.
purge
)
_ROWS_PER_CORE
=
50000
def
write_to_temp_buffer
(
dataframe
,
buffer_folder
,
columns
):
if
buffer_folder
is
None
:
_
,
buffer_path
=
tempfile
.
mkstemp
()
else
:
tf
.
io
.
gfile
.
makedirs
(
buffer_folder
)
buffer_path
=
os
.
path
.
join
(
buffer_folder
,
str
(
uuid
.
uuid4
()))
_GARBAGE_COLLECTOR
.
register
(
buffer_path
)
return
write_to_buffer
(
dataframe
,
buffer_path
,
columns
)
def
iter_shard_dataframe
(
df
,
rows_per_core
=
1000
):
"""Two way shard of a dataframe.
This function evenly shards a dataframe so that it can be mapped efficiently.
It yields a list of dataframes with length equal to the number of CPU cores,
with each dataframe having rows_per_core rows. (Except for the last batch
which may have fewer rows in the dataframes.) Passing vectorized inputs to
a pool is more effecient than iterating through a dataframe in serial and
passing a list of inputs to the pool.
Args:
df: Pandas dataframe to be sharded.
rows_per_core: Number of rows in each shard.
Returns:
A list of dataframe shards.
"""
n
=
len
(
df
)
num_cores
=
min
([
multiprocessing
.
cpu_count
(),
n
])
num_blocks
=
int
(
np
.
ceil
(
n
/
num_cores
/
rows_per_core
))
max_batch_size
=
num_cores
*
rows_per_core
for
i
in
range
(
num_blocks
):
min_index
=
i
*
max_batch_size
max_index
=
min
([(
i
+
1
)
*
max_batch_size
,
n
])
df_shard
=
df
[
min_index
:
max_index
]
n_shard
=
len
(
df_shard
)
boundaries
=
np
.
linspace
(
0
,
n_shard
,
num_cores
+
1
,
dtype
=
np
.
int64
)
yield
[
df_shard
[
boundaries
[
j
]:
boundaries
[
j
+
1
]]
for
j
in
range
(
num_cores
)]
def
_shard_dict_to_examples
(
shard_dict
):
"""Converts a dict of arrays into a list of example bytes."""
n
=
[
i
for
i
in
shard_dict
.
values
()][
0
].
shape
[
0
]
feature_list
=
[{}
for
_
in
range
(
n
)]
for
column
,
values
in
shard_dict
.
items
():
if
len
(
values
.
shape
)
==
1
:
values
=
np
.
reshape
(
values
,
values
.
shape
+
(
1
,))
if
values
.
dtype
.
kind
==
"i"
:
feature_map
=
lambda
x
:
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
x
))
elif
values
.
dtype
.
kind
==
"f"
:
feature_map
=
lambda
x
:
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
x
))
else
:
raise
ValueError
(
"Invalid dtype"
)
for
i
in
range
(
n
):
feature_list
[
i
][
column
]
=
feature_map
(
values
[
i
])
examples
=
[
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
example_features
))
for
example_features
in
feature_list
]
return
[
e
.
SerializeToString
()
for
e
in
examples
]
def
_serialize_shards
(
df_shards
,
columns
,
pool
,
writer
):
"""Map sharded dataframes to bytes, and write them to a buffer.
Args:
df_shards: A list of pandas dataframes. (Should be of similar size)
columns: The dataframe columns to be serialized.
pool: A pool to serialize in parallel.
writer: A TFRecordWriter to write the serialized shards.
"""
# Pandas does not store columns of arrays as nd arrays. stack remedies this.
map_inputs
=
[{
c
:
np
.
stack
(
shard
[
c
].
values
,
axis
=
0
)
for
c
in
columns
}
for
shard
in
df_shards
]
# Failure within pools is very irksome. Thus, it is better to thoroughly check
# inputs in the main process.
for
inp
in
map_inputs
:
# Check that all fields have the same number of rows.
assert
len
(
set
([
v
.
shape
[
0
]
for
v
in
inp
.
values
()]))
==
1
for
val
in
inp
.
values
():
assert
hasattr
(
val
,
"dtype"
)
assert
hasattr
(
val
.
dtype
,
"kind"
)
assert
val
.
dtype
.
kind
in
(
"i"
,
"f"
)
assert
len
(
val
.
shape
)
in
(
1
,
2
)
shard_bytes
=
pool
.
map
(
_shard_dict_to_examples
,
map_inputs
)
for
s
in
shard_bytes
:
for
example
in
s
:
writer
.
write
(
example
)
def
write_to_buffer
(
dataframe
,
buffer_path
,
columns
,
expected_size
=
None
):
"""Write a dataframe to a binary file for a dataset to consume.
Args:
dataframe: The pandas dataframe to be serialized.
buffer_path: The path where the serialized results will be written.
columns: The dataframe columns to be serialized.
expected_size: The size in bytes of the serialized results. This is used to
lazily construct the buffer.
Returns:
The path of the buffer.
"""
if
(
tf
.
io
.
gfile
.
exists
(
buffer_path
)
and
tf
.
io
.
gfile
.
stat
(
buffer_path
).
length
>
0
):
actual_size
=
tf
.
io
.
gfile
.
stat
(
buffer_path
).
length
if
expected_size
==
actual_size
:
return
buffer_path
logging
.
warning
(
"Existing buffer {} has size {}. Expected size {}. Deleting and "
"rebuilding buffer."
.
format
(
buffer_path
,
actual_size
,
expected_size
))
tf
.
io
.
gfile
.
remove
(
buffer_path
)
if
dataframe
is
None
:
raise
ValueError
(
"dataframe was None but a valid existing buffer was not found."
)
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
split
(
buffer_path
)[
0
])
logging
.
info
(
"Constructing TFRecordDataset buffer: {}"
.
format
(
buffer_path
))
count
=
0
pool
=
multiprocessing
.
dummy
.
Pool
(
multiprocessing
.
cpu_count
())
try
:
with
tf
.
io
.
TFRecordWriter
(
buffer_path
)
as
writer
:
for
df_shards
in
iter_shard_dataframe
(
df
=
dataframe
,
rows_per_core
=
_ROWS_PER_CORE
):
_serialize_shards
(
df_shards
,
columns
,
pool
,
writer
)
count
+=
sum
([
len
(
s
)
for
s
in
df_shards
])
logging
.
info
(
"{}/{} examples written."
.
format
(
str
(
count
).
ljust
(
8
),
len
(
dataframe
)))
finally
:
pool
.
terminate
()
logging
.
info
(
"Buffer write complete."
)
return
buffer_path
official/r1/utils/data/file_io_test.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for binary data file utilities."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
contextlib
import
multiprocessing
# pylint: disable=wrong-import-order
import
numpy
as
np
import
pandas
as
pd
import
tensorflow
as
tf
# pylint: enable=wrong-import-order
from
official.r1.utils.data
import
file_io
_RAW_ROW
=
"raw_row"
_DUMMY_COL
=
"column_0"
_DUMMY_VEC_COL
=
"column_1"
_DUMMY_VEC_LEN
=
4
_ROWS_PER_CORE
=
4
_TEST_CASES
=
[
# One batch of one
dict
(
row_count
=
1
,
cpu_count
=
1
,
expected
=
[
[[
0
]]
]),
dict
(
row_count
=
10
,
cpu_count
=
1
,
expected
=
[
[[
0
,
1
,
2
,
3
]],
[[
4
,
5
,
6
,
7
]],
[[
8
,
9
]]
]),
dict
(
row_count
=
21
,
cpu_count
=
1
,
expected
=
[
[[
0
,
1
,
2
,
3
]],
[[
4
,
5
,
6
,
7
]],
[[
8
,
9
,
10
,
11
]],
[[
12
,
13
,
14
,
15
]],
[[
16
,
17
,
18
,
19
]],
[[
20
]]
]),
dict
(
row_count
=
1
,
cpu_count
=
4
,
expected
=
[
[[
0
]]
]),
dict
(
row_count
=
10
,
cpu_count
=
4
,
expected
=
[
[[
0
,
1
],
[
2
,
3
,
4
],
[
5
,
6
],
[
7
,
8
,
9
]]
]),
dict
(
row_count
=
21
,
cpu_count
=
4
,
expected
=
[
[[
0
,
1
,
2
,
3
],
[
4
,
5
,
6
,
7
],
[
8
,
9
,
10
,
11
],
[
12
,
13
,
14
,
15
]],
[[
16
],
[
17
],
[
18
],
[
19
,
20
]]
]),
dict
(
row_count
=
10
,
cpu_count
=
8
,
expected
=
[
[[
0
],
[
1
],
[
2
],
[
3
,
4
],
[
5
],
[
6
],
[
7
],
[
8
,
9
]]
]),
dict
(
row_count
=
40
,
cpu_count
=
8
,
expected
=
[
[[
0
,
1
,
2
,
3
],
[
4
,
5
,
6
,
7
],
[
8
,
9
,
10
,
11
],
[
12
,
13
,
14
,
15
],
[
16
,
17
,
18
,
19
],
[
20
,
21
,
22
,
23
],
[
24
,
25
,
26
,
27
],
[
28
,
29
,
30
,
31
]],
[[
32
],
[
33
],
[
34
],
[
35
],
[
36
],
[
37
],
[
38
],
[
39
]]
]),
]
_FEATURE_MAP
=
{
_RAW_ROW
:
tf
.
io
.
FixedLenFeature
([
1
],
dtype
=
tf
.
int64
),
_DUMMY_COL
:
tf
.
io
.
FixedLenFeature
([
1
],
dtype
=
tf
.
int64
),
_DUMMY_VEC_COL
:
tf
.
io
.
FixedLenFeature
([
_DUMMY_VEC_LEN
],
dtype
=
tf
.
float32
)
}
@
contextlib
.
contextmanager
def
fixed_core_count
(
cpu_count
):
"""Override CPU count.
file_io.py uses the cpu_count function to scale to the size of the instance.
However, this is not desirable for testing because it can make the test flaky.
Instead, this context manager fixes the count for more robust testing.
Args:
cpu_count: How many cores multiprocessing claims to have.
Yields:
Nothing. (for context manager only)
"""
old_count_fn
=
multiprocessing
.
cpu_count
multiprocessing
.
cpu_count
=
lambda
:
cpu_count
yield
multiprocessing
.
cpu_count
=
old_count_fn
class
BaseTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
BaseTest
,
self
).
setUp
()
tf
.
compat
.
v1
.
disable_eager_execution
()
def
_test_sharding
(
self
,
row_count
,
cpu_count
,
expected
):
df
=
pd
.
DataFrame
({
_DUMMY_COL
:
list
(
range
(
row_count
))})
with
fixed_core_count
(
cpu_count
):
shards
=
list
(
file_io
.
iter_shard_dataframe
(
df
,
_ROWS_PER_CORE
))
result
=
[[
j
[
_DUMMY_COL
].
tolist
()
for
j
in
i
]
for
i
in
shards
]
self
.
assertAllEqual
(
expected
,
result
)
def
test_tiny_rows_low_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
0
])
def
test_small_rows_low_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
1
])
def
test_large_rows_low_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
2
])
def
test_tiny_rows_medium_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
3
])
def
test_small_rows_medium_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
4
])
def
test_large_rows_medium_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
5
])
def
test_small_rows_large_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
6
])
def
test_large_rows_large_core
(
self
):
self
.
_test_sharding
(
**
_TEST_CASES
[
7
])
def
_serialize_deserialize
(
self
,
num_cores
=
1
,
num_rows
=
20
):
np
.
random
.
seed
(
1
)
df
=
pd
.
DataFrame
({
# Serialization order is only deterministic for num_cores=1. raw_row is
# used in validation after the deserialization.
_RAW_ROW
:
np
.
array
(
range
(
num_rows
),
dtype
=
np
.
int64
),
_DUMMY_COL
:
np
.
random
.
randint
(
0
,
35
,
size
=
(
num_rows
,)),
_DUMMY_VEC_COL
:
[
np
.
array
([
np
.
random
.
random
()
for
_
in
range
(
_DUMMY_VEC_LEN
)])
for
i
in
range
(
num_rows
)
# pylint: disable=unused-variable
]
})
with
fixed_core_count
(
num_cores
):
buffer_path
=
file_io
.
write_to_temp_buffer
(
df
,
self
.
get_temp_dir
(),
[
_RAW_ROW
,
_DUMMY_COL
,
_DUMMY_VEC_COL
])
with
self
.
session
(
graph
=
tf
.
Graph
())
as
sess
:
dataset
=
tf
.
data
.
TFRecordDataset
(
buffer_path
)
dataset
=
dataset
.
batch
(
1
).
map
(
lambda
x
:
tf
.
io
.
parse_example
(
serialized
=
x
,
features
=
_FEATURE_MAP
))
data_iter
=
tf
.
compat
.
v1
.
data
.
make_one_shot_iterator
(
dataset
)
seen_rows
=
set
()
for
i
in
range
(
num_rows
+
5
):
row
=
data_iter
.
get_next
()
try
:
row_id
,
val_0
,
val_1
=
sess
.
run
(
[
row
[
_RAW_ROW
],
row
[
_DUMMY_COL
],
row
[
_DUMMY_VEC_COL
]])
row_id
,
val_0
,
val_1
=
row_id
[
0
][
0
],
val_0
[
0
][
0
],
val_1
[
0
]
assert
row_id
not
in
seen_rows
seen_rows
.
add
(
row_id
)
self
.
assertEqual
(
val_0
,
df
[
_DUMMY_COL
][
row_id
])
self
.
assertAllClose
(
val_1
,
df
[
_DUMMY_VEC_COL
][
row_id
])
self
.
assertLess
(
i
,
num_rows
,
msg
=
"Too many rows."
)
except
tf
.
errors
.
OutOfRangeError
:
self
.
assertGreaterEqual
(
i
,
num_rows
,
msg
=
"Too few rows."
)
file_io
.
_GARBAGE_COLLECTOR
.
purge
()
assert
not
tf
.
io
.
gfile
.
exists
(
buffer_path
)
def
test_serialize_deserialize_0
(
self
):
self
.
_serialize_deserialize
(
num_cores
=
1
)
def
test_serialize_deserialize_1
(
self
):
self
.
_serialize_deserialize
(
num_cores
=
2
)
def
test_serialize_deserialize_2
(
self
):
self
.
_serialize_deserialize
(
num_cores
=
8
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/r1/utils/export.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for exporting models as SavedModels or other types."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
def
build_tensor_serving_input_receiver_fn
(
shape
,
dtype
=
tf
.
float32
,
batch_size
=
1
):
"""Returns a input_receiver_fn that can be used during serving.
This expects examples to come through as float tensors, and simply
wraps them as TensorServingInputReceivers.
Arguably, this should live in tf.estimator.export. Testing here first.
Args:
shape: list representing target size of a single example.
dtype: the expected datatype for the input example
batch_size: number of input tensors that will be passed for prediction
Returns:
A function that itself returns a TensorServingInputReceiver.
"""
def
serving_input_receiver_fn
():
# Prep a placeholder where the input example will be fed in
features
=
tf
.
compat
.
v1
.
placeholder
(
dtype
=
dtype
,
shape
=
[
batch_size
]
+
shape
,
name
=
'input_tensor'
)
return
tf
.
estimator
.
export
.
TensorServingInputReceiver
(
features
=
features
,
receiver_tensors
=
features
)
return
serving_input_receiver_fn
official/r1/utils/export_test.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exporting utils."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.r1.utils
import
export
class
ExportUtilsTest
(
tf
.
test
.
TestCase
):
"""Tests for the ExportUtils."""
def
test_build_tensor_serving_input_receiver_fn
(
self
):
receiver_fn
=
export
.
build_tensor_serving_input_receiver_fn
(
shape
=
[
4
,
5
])
with
tf
.
Graph
().
as_default
():
receiver
=
receiver_fn
()
self
.
assertIsInstance
(
receiver
,
tf
.
estimator
.
export
.
TensorServingInputReceiver
)
self
.
assertIsInstance
(
receiver
.
features
,
tf
.
Tensor
)
self
.
assertEqual
(
receiver
.
features
.
shape
,
tf
.
TensorShape
([
1
,
4
,
5
]))
self
.
assertEqual
(
receiver
.
features
.
dtype
,
tf
.
float32
)
self
.
assertIsInstance
(
receiver
.
receiver_tensors
,
dict
)
# Note that Python 3 can no longer index .values() directly; cast to list.
self
.
assertEqual
(
list
(
receiver
.
receiver_tensors
.
values
())[
0
].
shape
,
tf
.
TensorShape
([
1
,
4
,
5
]))
def
test_build_tensor_serving_input_receiver_fn_batch_dtype
(
self
):
receiver_fn
=
export
.
build_tensor_serving_input_receiver_fn
(
shape
=
[
4
,
5
],
dtype
=
tf
.
int8
,
batch_size
=
10
)
with
tf
.
Graph
().
as_default
():
receiver
=
receiver_fn
()
self
.
assertIsInstance
(
receiver
,
tf
.
estimator
.
export
.
TensorServingInputReceiver
)
self
.
assertIsInstance
(
receiver
.
features
,
tf
.
Tensor
)
self
.
assertEqual
(
receiver
.
features
.
shape
,
tf
.
TensorShape
([
10
,
4
,
5
]))
self
.
assertEqual
(
receiver
.
features
.
dtype
,
tf
.
int8
)
self
.
assertIsInstance
(
receiver
.
receiver_tensors
,
dict
)
# Note that Python 3 can no longer index .values() directly; cast to list.
self
.
assertEqual
(
list
(
receiver
.
receiver_tensors
.
values
())[
0
].
shape
,
tf
.
TensorShape
([
10
,
4
,
5
]))
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/r1/utils/logs/__init__.py
deleted
100644 → 0
View file @
d4f5c193
official/r1/utils/logs/cloud_lib.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that interact with cloud service.
"""
import
requests
GCP_METADATA_URL
=
"http://metadata/computeMetadata/v1/instance/hostname"
GCP_METADATA_HEADER
=
{
"Metadata-Flavor"
:
"Google"
}
def
on_gcp
():
"""Detect whether the current running environment is on GCP."""
try
:
# Timeout in 5 seconds, in case the test environment has connectivity issue.
# There is not default timeout, which means it might block forever.
response
=
requests
.
get
(
GCP_METADATA_URL
,
headers
=
GCP_METADATA_HEADER
,
timeout
=
5
)
return
response
.
status_code
==
200
except
requests
.
exceptions
.
RequestException
:
return
False
official/r1/utils/logs/cloud_lib_test.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cloud_lib."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
mock
import
requests
from
official.r1.utils.logs
import
cloud_lib
class
CloudLibTest
(
unittest
.
TestCase
):
@
mock
.
patch
(
"requests.get"
)
def
test_on_gcp
(
self
,
mock_requests_get
):
mock_response
=
mock
.
MagicMock
()
mock_requests_get
.
return_value
=
mock_response
mock_response
.
status_code
=
200
self
.
assertEqual
(
cloud_lib
.
on_gcp
(),
True
)
@
mock
.
patch
(
"requests.get"
)
def
test_not_on_gcp
(
self
,
mock_requests_get
):
mock_requests_get
.
side_effect
=
requests
.
exceptions
.
ConnectionError
()
self
.
assertEqual
(
cloud_lib
.
on_gcp
(),
False
)
if
__name__
==
"__main__"
:
unittest
.
main
()
official/r1/utils/logs/guidelines.md
deleted
100644 → 0
View file @
d4f5c193
# Logging in official models
This library adds logging functions that print or save tensor values. Official models should define all common hooks
(using hooks helper) and a benchmark logger.
1.
**Training Hooks**
Hooks are a TensorFlow concept that define specific actions at certain points of the execution. We use them to obtain and log
tensor values during training.
hooks_helper.py provides an easy way to create common hooks. The following hooks are currently defined:
*
LoggingTensorHook: Logs tensor values
*
ProfilerHook: Writes a timeline json that can be loaded into chrome://tracing.
*
ExamplesPerSecondHook: Logs the number of examples processed per second.
*
LoggingMetricHook: Similar to LoggingTensorHook, except that the tensors are logged in a format defined by our data
anaylsis pipeline.
2.
**Benchmarks**
The benchmark logger provides useful functions for logging environment information, and evaluation results.
The module also contains a context which is used to update the status of the run.
Example usage:
```
from absl import app as absl_app
from official.utils.logs import hooks_helper
from official.utils.logs import logger
def model_main(flags_obj):
estimator = ...
benchmark_logger = logger.get_benchmark_logger()
benchmark_logger.log_run_info(...)
train_hooks = hooks_helper.get_train_hooks(...)
for epoch in range(10):
estimator.train(..., hooks=train_hooks)
eval_results = estimator.evaluate(...)
# Log a dictionary of metrics
benchmark_logger.log_evaluation_result(eval_results)
# Log an individual metric
benchmark_logger.log_metric(...)
def main(_):
with logger.benchmark_context(flags.FLAGS):
model_main(flags.FLAGS)
if __name__ == "__main__":
# define flags
absl_app.run(main)
```
official/r1/utils/logs/hooks.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hook that counts examples per second every N steps or seconds."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.r1.utils.logs
import
logger
class
ExamplesPerSecondHook
(
tf
.
estimator
.
SessionRunHook
):
"""Hook to print out examples per second.
Total time is tracked and then divided by the total number of steps
to get the average step time and then batch_size is used to determine
the running average of examples per second. The examples per second for the
most recent interval is also logged.
"""
def
__init__
(
self
,
batch_size
,
every_n_steps
=
None
,
every_n_secs
=
None
,
warm_steps
=
0
,
metric_logger
=
None
):
"""Initializer for ExamplesPerSecondHook.
Args:
batch_size: Total batch size across all workers used to calculate
examples/second from global time.
every_n_steps: Log stats every n steps.
every_n_secs: Log stats every n seconds. Exactly one of the
`every_n_steps` or `every_n_secs` should be set.
warm_steps: The number of steps to be skipped before logging and running
average calculation. warm_steps steps refers to global steps across all
workers, not on each worker
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log. If None, BaseBenchmarkLogger will
be used.
Raises:
ValueError: if neither `every_n_steps` or `every_n_secs` is set, or
both are set.
"""
if
(
every_n_steps
is
None
)
==
(
every_n_secs
is
None
):
raise
ValueError
(
"exactly one of every_n_steps"
" and every_n_secs should be provided."
)
self
.
_logger
=
metric_logger
or
logger
.
BaseBenchmarkLogger
()
self
.
_timer
=
tf
.
estimator
.
SecondOrStepTimer
(
every_steps
=
every_n_steps
,
every_secs
=
every_n_secs
)
self
.
_step_train_time
=
0
self
.
_total_steps
=
0
self
.
_batch_size
=
batch_size
self
.
_warm_steps
=
warm_steps
# List of examples per second logged every_n_steps.
self
.
current_examples_per_sec_list
=
[]
def
begin
(
self
):
"""Called once before using the session to check global step."""
self
.
_global_step_tensor
=
tf
.
compat
.
v1
.
train
.
get_global_step
()
if
self
.
_global_step_tensor
is
None
:
raise
RuntimeError
(
"Global step should be created to use StepCounterHook."
)
def
before_run
(
self
,
run_context
):
# pylint: disable=unused-argument
"""Called before each call to run().
Args:
run_context: A SessionRunContext object.
Returns:
A SessionRunArgs object or None if never triggered.
"""
return
tf
.
estimator
.
SessionRunArgs
(
self
.
_global_step_tensor
)
def
after_run
(
self
,
run_context
,
run_values
):
# pylint: disable=unused-argument
"""Called after each call to run().
Args:
run_context: A SessionRunContext object.
run_values: A SessionRunValues object.
"""
global_step
=
run_values
.
results
if
self
.
_timer
.
should_trigger_for_step
(
global_step
)
and
global_step
>
self
.
_warm_steps
:
elapsed_time
,
elapsed_steps
=
self
.
_timer
.
update_last_triggered_step
(
global_step
)
if
elapsed_time
is
not
None
:
self
.
_step_train_time
+=
elapsed_time
self
.
_total_steps
+=
elapsed_steps
# average examples per second is based on the total (accumulative)
# training steps and training time so far
average_examples_per_sec
=
self
.
_batch_size
*
(
self
.
_total_steps
/
self
.
_step_train_time
)
# current examples per second is based on the elapsed training steps
# and training time per batch
current_examples_per_sec
=
self
.
_batch_size
*
(
elapsed_steps
/
elapsed_time
)
# Logs entries to be read from hook during or after run.
self
.
current_examples_per_sec_list
.
append
(
current_examples_per_sec
)
self
.
_logger
.
log_metric
(
"average_examples_per_sec"
,
average_examples_per_sec
,
global_step
=
global_step
)
self
.
_logger
.
log_metric
(
"current_examples_per_sec"
,
current_examples_per_sec
,
global_step
=
global_step
)
official/r1/utils/logs/hooks_helper.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks helper to return a list of TensorFlow hooks for training by name.
More hooks can be added to this set. To add a new hook, 1) add the new hook to
the registry in HOOKS, 2) add a corresponding function that parses out necessary
parameters.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.r1.utils.logs
import
hooks
from
official.r1.utils.logs
import
logger
from
official.r1.utils.logs
import
metric_hook
_TENSORS_TO_LOG
=
dict
((
x
,
x
)
for
x
in
[
'learning_rate'
,
'cross_entropy'
,
'train_accuracy'
])
def
get_train_hooks
(
name_list
,
use_tpu
=
False
,
**
kwargs
):
"""Factory for getting a list of TensorFlow hooks for training by name.
Args:
name_list: a list of strings to name desired hook classes. Allowed:
LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined
as keys in HOOKS
use_tpu: Boolean of whether computation occurs on a TPU. This will disable
hooks altogether.
**kwargs: a dictionary of arguments to the hooks.
Returns:
list of instantiated hooks, ready to be used in a classifier.train call.
Raises:
ValueError: if an unrecognized name is passed.
"""
if
not
name_list
:
return
[]
if
use_tpu
:
logging
.
warning
(
'hooks_helper received name_list `%s`, but a '
'TPU is specified. No hooks will be used.'
,
name_list
)
return
[]
train_hooks
=
[]
for
name
in
name_list
:
hook_name
=
HOOKS
.
get
(
name
.
strip
().
lower
())
if
hook_name
is
None
:
raise
ValueError
(
'Unrecognized training hook requested: {}'
.
format
(
name
))
else
:
train_hooks
.
append
(
hook_name
(
**
kwargs
))
return
train_hooks
def
get_logging_tensor_hook
(
every_n_iter
=
100
,
tensors_to_log
=
None
,
**
kwargs
):
# pylint: disable=unused-argument
"""Function to get LoggingTensorHook.
Args:
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
**kwargs: a dictionary of arguments to LoggingTensorHook.
Returns:
Returns a LoggingTensorHook with a standard set of tensors that will be
printed to stdout.
"""
if
tensors_to_log
is
None
:
tensors_to_log
=
_TENSORS_TO_LOG
return
tf
.
estimator
.
LoggingTensorHook
(
tensors
=
tensors_to_log
,
every_n_iter
=
every_n_iter
)
def
get_profiler_hook
(
model_dir
,
save_steps
=
1000
,
**
kwargs
):
# pylint: disable=unused-argument
"""Function to get ProfilerHook.
Args:
model_dir: The directory to save the profile traces to.
save_steps: `int`, print profile traces every N steps.
**kwargs: a dictionary of arguments to ProfilerHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return
tf
.
estimator
.
ProfilerHook
(
save_steps
=
save_steps
,
output_dir
=
model_dir
)
def
get_examples_per_second_hook
(
every_n_steps
=
100
,
batch_size
=
128
,
warm_steps
=
5
,
**
kwargs
):
# pylint: disable=unused-argument
"""Function to get ExamplesPerSecondHook.
Args:
every_n_steps: `int`, print current and average examples per second every
N steps.
batch_size: `int`, total batch size used to calculate examples/second from
global time.
warm_steps: skip this number of steps before logging and running average.
**kwargs: a dictionary of arguments to ExamplesPerSecondHook.
Returns:
Returns a ProfilerHook that writes out timelines that can be loaded into
profiling tools like chrome://tracing.
"""
return
hooks
.
ExamplesPerSecondHook
(
batch_size
=
batch_size
,
every_n_steps
=
every_n_steps
,
warm_steps
=
warm_steps
,
metric_logger
=
logger
.
get_benchmark_logger
())
def
get_logging_metric_hook
(
tensors_to_log
=
None
,
every_n_secs
=
600
,
**
kwargs
):
# pylint: disable=unused-argument
"""Function to get LoggingMetricHook.
Args:
tensors_to_log: List of tensor names or dictionary mapping labels to tensor
names. If not set, log _TENSORS_TO_LOG by default.
every_n_secs: `int`, the frequency for logging the metric. Default to every
10 mins.
**kwargs: a dictionary of arguments.
Returns:
Returns a LoggingMetricHook that saves tensor values in a JSON format.
"""
if
tensors_to_log
is
None
:
tensors_to_log
=
_TENSORS_TO_LOG
return
metric_hook
.
LoggingMetricHook
(
tensors
=
tensors_to_log
,
metric_logger
=
logger
.
get_benchmark_logger
(),
every_n_secs
=
every_n_secs
)
def
get_step_counter_hook
(
**
kwargs
):
"""Function to get StepCounterHook."""
del
kwargs
return
tf
.
estimator
.
StepCounterHook
()
# A dictionary to map one hook name and its corresponding function
HOOKS
=
{
'loggingtensorhook'
:
get_logging_tensor_hook
,
'profilerhook'
:
get_profiler_hook
,
'examplespersecondhook'
:
get_examples_per_second_hook
,
'loggingmetrichook'
:
get_logging_metric_hook
,
'stepcounterhook'
:
get_step_counter_hook
}
official/r1/utils/logs/hooks_test.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for hooks."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
from
absl
import
logging
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.r1.utils.logs
import
hooks
from
official.r1.utils.logs
import
mock_lib
logging
.
set_verbosity
(
logging
.
DEBUG
)
class
ExamplesPerSecondHookTest
(
tf
.
test
.
TestCase
):
"""Tests for the ExamplesPerSecondHook.
In the test, we explicitly run global_step tensor after train_op in order to
keep the global_step value and the train_op (which increase the glboal_step
by 1) consistent. This is to correct the discrepancies in reported global_step
value when running on GPUs.
"""
def
setUp
(
self
):
"""Mock out logging calls to verify if correct info is being monitored."""
self
.
_logger
=
mock_lib
.
MockBenchmarkLogger
()
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
tf
.
compat
.
v1
.
train
.
create_global_step
()
self
.
train_op
=
tf
.
compat
.
v1
.
assign_add
(
tf
.
compat
.
v1
.
train
.
get_global_step
(),
1
)
self
.
global_step
=
tf
.
compat
.
v1
.
train
.
get_global_step
()
def
test_raise_in_both_secs_and_steps
(
self
):
with
self
.
assertRaises
(
ValueError
):
hooks
.
ExamplesPerSecondHook
(
batch_size
=
256
,
every_n_steps
=
10
,
every_n_secs
=
20
,
metric_logger
=
self
.
_logger
)
def
test_raise_in_none_secs_and_steps
(
self
):
with
self
.
assertRaises
(
ValueError
):
hooks
.
ExamplesPerSecondHook
(
batch_size
=
256
,
every_n_steps
=
None
,
every_n_secs
=
None
,
metric_logger
=
self
.
_logger
)
def
_validate_log_every_n_steps
(
self
,
every_n_steps
,
warm_steps
):
hook
=
hooks
.
ExamplesPerSecondHook
(
batch_size
=
256
,
every_n_steps
=
every_n_steps
,
warm_steps
=
warm_steps
,
metric_logger
=
self
.
_logger
)
with
tf
.
compat
.
v1
.
train
.
MonitoredSession
(
tf
.
compat
.
v1
.
train
.
ChiefSessionCreator
(),
[
hook
])
as
mon_sess
:
for
_
in
range
(
every_n_steps
):
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess
.
run
(
self
.
train_op
)
mon_sess
.
run
(
self
.
global_step
)
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
mon_sess
.
run
(
self
.
train_op
)
global_step_val
=
mon_sess
.
run
(
self
.
global_step
)
if
global_step_val
>
warm_steps
:
self
.
_assert_metrics
()
else
:
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
# Add additional run to verify proper reset when called multiple times.
prev_log_len
=
len
(
self
.
_logger
.
logged_metric
)
mon_sess
.
run
(
self
.
train_op
)
global_step_val
=
mon_sess
.
run
(
self
.
global_step
)
if
every_n_steps
==
1
and
global_step_val
>
warm_steps
:
# Each time, we log two additional metrics. Did exactly 2 get added?
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
prev_log_len
+
2
)
else
:
# No change in the size of the metric list.
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
prev_log_len
)
def
test_examples_per_sec_every_1_steps
(
self
):
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
1
,
0
)
def
test_examples_per_sec_every_5_steps
(
self
):
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
5
,
0
)
def
test_examples_per_sec_every_1_steps_with_warm_steps
(
self
):
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
1
,
10
)
def
test_examples_per_sec_every_5_steps_with_warm_steps
(
self
):
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
5
,
10
)
def
_validate_log_every_n_secs
(
self
,
every_n_secs
):
hook
=
hooks
.
ExamplesPerSecondHook
(
batch_size
=
256
,
every_n_steps
=
None
,
every_n_secs
=
every_n_secs
,
metric_logger
=
self
.
_logger
)
with
tf
.
compat
.
v1
.
train
.
MonitoredSession
(
tf
.
compat
.
v1
.
train
.
ChiefSessionCreator
(),
[
hook
])
as
mon_sess
:
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess
.
run
(
self
.
train_op
)
mon_sess
.
run
(
self
.
global_step
)
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
time
.
sleep
(
every_n_secs
)
mon_sess
.
run
(
self
.
train_op
)
mon_sess
.
run
(
self
.
global_step
)
self
.
_assert_metrics
()
def
test_examples_per_sec_every_1_secs
(
self
):
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_secs
(
1
)
def
test_examples_per_sec_every_5_secs
(
self
):
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_secs
(
5
)
def
_assert_metrics
(
self
):
metrics
=
self
.
_logger
.
logged_metric
self
.
assertEqual
(
metrics
[
-
2
][
"name"
],
"average_examples_per_sec"
)
self
.
assertEqual
(
metrics
[
-
1
][
"name"
],
"current_examples_per_sec"
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/r1/utils/logs/logger.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging utilities for benchmark.
For collecting local environment metrics like CPU and memory, certain python
packages need be installed. See README for details.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
contextlib
import
datetime
import
json
import
numbers
import
os
import
threading
import
uuid
from
absl
import
flags
from
absl
import
logging
from
six.moves
import
_thread
as
thread
import
tensorflow
as
tf
from
tensorflow.python.client
import
device_lib
from
official.r1.utils.logs
import
cloud_lib
METRIC_LOG_FILE_NAME
=
"metric.log"
BENCHMARK_RUN_LOG_FILE_NAME
=
"benchmark_run.log"
_DATE_TIME_FORMAT_PATTERN
=
"%Y-%m-%dT%H:%M:%S.%fZ"
GCP_TEST_ENV
=
"GCP"
RUN_STATUS_SUCCESS
=
"success"
RUN_STATUS_FAILURE
=
"failure"
RUN_STATUS_RUNNING
=
"running"
FLAGS
=
flags
.
FLAGS
# Don't use it directly. Use get_benchmark_logger to access a logger.
_benchmark_logger
=
None
_logger_lock
=
threading
.
Lock
()
def
config_benchmark_logger
(
flag_obj
=
None
):
"""Config the global benchmark logger."""
_logger_lock
.
acquire
()
try
:
global
_benchmark_logger
if
not
flag_obj
:
flag_obj
=
FLAGS
if
(
not
hasattr
(
flag_obj
,
"benchmark_logger_type"
)
or
flag_obj
.
benchmark_logger_type
==
"BaseBenchmarkLogger"
):
_benchmark_logger
=
BaseBenchmarkLogger
()
elif
flag_obj
.
benchmark_logger_type
==
"BenchmarkFileLogger"
:
_benchmark_logger
=
BenchmarkFileLogger
(
flag_obj
.
benchmark_log_dir
)
else
:
raise
ValueError
(
"Unrecognized benchmark_logger_type: %s"
%
flag_obj
.
benchmark_logger_type
)
finally
:
_logger_lock
.
release
()
return
_benchmark_logger
def
get_benchmark_logger
():
if
not
_benchmark_logger
:
config_benchmark_logger
()
return
_benchmark_logger
@
contextlib
.
contextmanager
def
benchmark_context
(
flag_obj
):
"""Context of benchmark, which will update status of the run accordingly."""
benchmark_logger
=
config_benchmark_logger
(
flag_obj
)
try
:
yield
benchmark_logger
.
on_finish
(
RUN_STATUS_SUCCESS
)
except
Exception
:
# pylint: disable=broad-except
# Catch all the exception, update the run status to be failure, and re-raise
benchmark_logger
.
on_finish
(
RUN_STATUS_FAILURE
)
raise
class
BaseBenchmarkLogger
(
object
):
"""Class to log the benchmark information to STDOUT."""
def
log_evaluation_result
(
self
,
eval_results
):
"""Log the evaluation result.
The evaluate result is a dictionary that contains metrics defined in
model_fn. It also contains a entry for global_step which contains the value
of the global step when evaluation was performed.
Args:
eval_results: dict, the result of evaluate.
"""
if
not
isinstance
(
eval_results
,
dict
):
logging
.
warning
(
"eval_results should be dictionary for logging. Got %s"
,
type
(
eval_results
))
return
global_step
=
eval_results
[
tf
.
compat
.
v1
.
GraphKeys
.
GLOBAL_STEP
]
for
key
in
sorted
(
eval_results
):
if
key
!=
tf
.
compat
.
v1
.
GraphKeys
.
GLOBAL_STEP
:
self
.
log_metric
(
key
,
eval_results
[
key
],
global_step
=
global_step
)
def
log_metric
(
self
,
name
,
value
,
unit
=
None
,
global_step
=
None
,
extras
=
None
):
"""Log the benchmark metric information to local file.
Currently the logging is done in a synchronized way. This should be updated
to log asynchronously.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, E.g "image per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric
=
_process_metric_to_json
(
name
,
value
,
unit
,
global_step
,
extras
)
if
metric
:
logging
.
info
(
"Benchmark metric: %s"
,
metric
)
def
log_run_info
(
self
,
model_name
,
dataset_name
,
run_params
,
test_id
=
None
):
logging
.
info
(
"Benchmark run: %s"
,
_gather_run_info
(
model_name
,
dataset_name
,
run_params
,
test_id
))
def
on_finish
(
self
,
status
):
pass
class
BenchmarkFileLogger
(
BaseBenchmarkLogger
):
"""Class to log the benchmark information to local disk."""
def
__init__
(
self
,
logging_dir
):
super
(
BenchmarkFileLogger
,
self
).
__init__
()
self
.
_logging_dir
=
logging_dir
if
not
tf
.
io
.
gfile
.
isdir
(
self
.
_logging_dir
):
tf
.
io
.
gfile
.
makedirs
(
self
.
_logging_dir
)
self
.
_metric_file_handler
=
tf
.
io
.
gfile
.
GFile
(
os
.
path
.
join
(
self
.
_logging_dir
,
METRIC_LOG_FILE_NAME
),
"a"
)
def
log_metric
(
self
,
name
,
value
,
unit
=
None
,
global_step
=
None
,
extras
=
None
):
"""Log the benchmark metric information to local file.
Currently the logging is done in a synchronized way. This should be updated
to log asynchronously.
Args:
name: string, the name of the metric to log.
value: number, the value of the metric. The value will not be logged if it
is not a number type.
unit: string, the unit of the metric, E.g "image per second".
global_step: int, the global_step when the metric is logged.
extras: map of string:string, the extra information about the metric.
"""
metric
=
_process_metric_to_json
(
name
,
value
,
unit
,
global_step
,
extras
)
if
metric
:
try
:
json
.
dump
(
metric
,
self
.
_metric_file_handler
)
self
.
_metric_file_handler
.
write
(
"
\n
"
)
self
.
_metric_file_handler
.
flush
()
except
(
TypeError
,
ValueError
)
as
e
:
logging
.
warning
(
"Failed to dump metric to log file: name %s, value %s, error %s"
,
name
,
value
,
e
)
def
log_run_info
(
self
,
model_name
,
dataset_name
,
run_params
,
test_id
=
None
):
"""Collect most of the TF runtime information for the local env.
The schema of the run info follows official/benchmark/datastore/schema.
Args:
model_name: string, the name of the model.
dataset_name: string, the name of dataset for training and evaluation.
run_params: dict, the dictionary of parameters for the run, it could
include hyperparameters or other params that are important for the run.
test_id: string, the unique name of the test run by the combination of key
parameters, eg batch size, num of GPU. It is hardware independent.
"""
run_info
=
_gather_run_info
(
model_name
,
dataset_name
,
run_params
,
test_id
)
with
tf
.
io
.
gfile
.
GFile
(
os
.
path
.
join
(
self
.
_logging_dir
,
BENCHMARK_RUN_LOG_FILE_NAME
),
"w"
)
as
f
:
try
:
json
.
dump
(
run_info
,
f
)
f
.
write
(
"
\n
"
)
except
(
TypeError
,
ValueError
)
as
e
:
logging
.
warning
(
"Failed to dump benchmark run info to log file: %s"
,
e
)
def
on_finish
(
self
,
status
):
self
.
_metric_file_handler
.
flush
()
self
.
_metric_file_handler
.
close
()
def
_gather_run_info
(
model_name
,
dataset_name
,
run_params
,
test_id
):
"""Collect the benchmark run information for the local environment."""
run_info
=
{
"model_name"
:
model_name
,
"dataset"
:
{
"name"
:
dataset_name
},
"machine_config"
:
{},
"test_id"
:
test_id
,
"run_date"
:
datetime
.
datetime
.
utcnow
().
strftime
(
_DATE_TIME_FORMAT_PATTERN
)}
_collect_tensorflow_info
(
run_info
)
_collect_tensorflow_environment_variables
(
run_info
)
_collect_run_params
(
run_info
,
run_params
)
_collect_memory_info
(
run_info
)
_collect_test_environment
(
run_info
)
return
run_info
def
_process_metric_to_json
(
name
,
value
,
unit
=
None
,
global_step
=
None
,
extras
=
None
):
"""Validate the metric data and generate JSON for insert."""
if
not
isinstance
(
value
,
numbers
.
Number
):
logging
.
warning
(
"Metric value to log should be a number. Got %s"
,
type
(
value
))
return
None
extras
=
_convert_to_json_dict
(
extras
)
return
{
"name"
:
name
,
"value"
:
float
(
value
),
"unit"
:
unit
,
"global_step"
:
global_step
,
"timestamp"
:
datetime
.
datetime
.
utcnow
().
strftime
(
_DATE_TIME_FORMAT_PATTERN
),
"extras"
:
extras
}
def
_collect_tensorflow_info
(
run_info
):
run_info
[
"tensorflow_version"
]
=
{
"version"
:
tf
.
version
.
VERSION
,
"git_hash"
:
tf
.
version
.
GIT_VERSION
}
def
_collect_run_params
(
run_info
,
run_params
):
"""Log the parameter information for the benchmark run."""
def
process_param
(
name
,
value
):
type_check
=
{
str
:
{
"name"
:
name
,
"string_value"
:
value
},
int
:
{
"name"
:
name
,
"long_value"
:
value
},
bool
:
{
"name"
:
name
,
"bool_value"
:
str
(
value
)},
float
:
{
"name"
:
name
,
"float_value"
:
value
},
}
return
type_check
.
get
(
type
(
value
),
{
"name"
:
name
,
"string_value"
:
str
(
value
)})
if
run_params
:
run_info
[
"run_parameters"
]
=
[
process_param
(
k
,
v
)
for
k
,
v
in
sorted
(
run_params
.
items
())]
def
_collect_tensorflow_environment_variables
(
run_info
):
run_info
[
"tensorflow_environment_variables"
]
=
[
{
"name"
:
k
,
"value"
:
v
}
for
k
,
v
in
sorted
(
os
.
environ
.
items
())
if
k
.
startswith
(
"TF_"
)]
def
_collect_memory_info
(
run_info
):
try
:
# Note: psutil is not installed in the TensorFlow OSS tree.
# It is installable via pip.
import
psutil
# pylint: disable=g-import-not-at-top
vmem
=
psutil
.
virtual_memory
()
run_info
[
"machine_config"
][
"memory_total"
]
=
vmem
.
total
run_info
[
"machine_config"
][
"memory_available"
]
=
vmem
.
available
except
ImportError
:
logging
.
warn
(
"'psutil' not imported. Memory info will not be logged."
)
def
_collect_test_environment
(
run_info
):
"""Detect the local environment, eg GCE, AWS or DGX, etc."""
if
cloud_lib
.
on_gcp
():
run_info
[
"test_environment"
]
=
GCP_TEST_ENV
# TODO(scottzhu): Add more testing env detection for other platform
def
_parse_gpu_model
(
physical_device_desc
):
# Assume all the GPU connected are same model
for
kv
in
physical_device_desc
.
split
(
","
):
k
,
_
,
v
=
kv
.
partition
(
":"
)
if
k
.
strip
()
==
"name"
:
return
v
.
strip
()
return
None
def
_convert_to_json_dict
(
input_dict
):
if
input_dict
:
return
[{
"name"
:
k
,
"value"
:
v
}
for
k
,
v
in
sorted
(
input_dict
.
items
())]
else
:
return
[]
official/r1/utils/logs/logger_test.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for benchmark logger."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
json
import
os
import
tempfile
import
time
import
unittest
from
absl
import
logging
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
from
official.r1.utils.logs
import
logger
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
keras_utils
class
BenchmarkLoggerTest
(
tf
.
test
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
BenchmarkLoggerTest
,
cls
).
setUpClass
()
flags_core
.
define_benchmark
()
def
test_get_default_benchmark_logger
(
self
):
with
flagsaver
.
flagsaver
(
benchmark_logger_type
=
"foo"
):
self
.
assertIsInstance
(
logger
.
get_benchmark_logger
(),
logger
.
BaseBenchmarkLogger
)
def
test_config_base_benchmark_logger
(
self
):
with
flagsaver
.
flagsaver
(
benchmark_logger_type
=
"BaseBenchmarkLogger"
):
logger
.
config_benchmark_logger
()
self
.
assertIsInstance
(
logger
.
get_benchmark_logger
(),
logger
.
BaseBenchmarkLogger
)
def
test_config_benchmark_file_logger
(
self
):
# Set the benchmark_log_dir first since the benchmark_logger_type will need
# the value to be set when it does the validation.
with
flagsaver
.
flagsaver
(
benchmark_log_dir
=
"/tmp"
):
with
flagsaver
.
flagsaver
(
benchmark_logger_type
=
"BenchmarkFileLogger"
):
logger
.
config_benchmark_logger
()
self
.
assertIsInstance
(
logger
.
get_benchmark_logger
(),
logger
.
BenchmarkFileLogger
)
class
BaseBenchmarkLoggerTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
BaseBenchmarkLoggerTest
,
self
).
setUp
()
self
.
_actual_log
=
logging
.
info
self
.
logged_message
=
None
def
mock_log
(
*
args
,
**
kwargs
):
self
.
logged_message
=
args
self
.
_actual_log
(
*
args
,
**
kwargs
)
logging
.
info
=
mock_log
def
tearDown
(
self
):
super
(
BaseBenchmarkLoggerTest
,
self
).
tearDown
()
logging
.
info
=
self
.
_actual_log
def
test_log_metric
(
self
):
log
=
logger
.
BaseBenchmarkLogger
()
log
.
log_metric
(
"accuracy"
,
0.999
,
global_step
=
1e4
,
extras
=
{
"name"
:
"value"
})
expected_log_prefix
=
"Benchmark metric:"
self
.
assertRegexpMatches
(
str
(
self
.
logged_message
),
expected_log_prefix
)
class
BenchmarkFileLoggerTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
BenchmarkFileLoggerTest
,
self
).
setUp
()
# Avoid pulling extra env vars from test environment which affects the test
# result, eg. Kokoro test has a TF_PKG env which affect the test case
# test_collect_tensorflow_environment_variables()
self
.
original_environ
=
dict
(
os
.
environ
)
os
.
environ
.
clear
()
def
tearDown
(
self
):
super
(
BenchmarkFileLoggerTest
,
self
).
tearDown
()
tf
.
io
.
gfile
.
rmtree
(
self
.
get_temp_dir
())
os
.
environ
.
clear
()
os
.
environ
.
update
(
self
.
original_environ
)
def
test_create_logging_dir
(
self
):
non_exist_temp_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"unknown_dir"
)
self
.
assertFalse
(
tf
.
io
.
gfile
.
isdir
(
non_exist_temp_dir
))
logger
.
BenchmarkFileLogger
(
non_exist_temp_dir
)
self
.
assertTrue
(
tf
.
io
.
gfile
.
isdir
(
non_exist_temp_dir
))
def
test_log_metric
(
self
):
log_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
log
=
logger
.
BenchmarkFileLogger
(
log_dir
)
log
.
log_metric
(
"accuracy"
,
0.999
,
global_step
=
1e4
,
extras
=
{
"name"
:
"value"
})
metric_log
=
os
.
path
.
join
(
log_dir
,
"metric.log"
)
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
metric_log
))
with
tf
.
io
.
gfile
.
GFile
(
metric_log
)
as
f
:
metric
=
json
.
loads
(
f
.
readline
())
self
.
assertEqual
(
metric
[
"name"
],
"accuracy"
)
self
.
assertEqual
(
metric
[
"value"
],
0.999
)
self
.
assertEqual
(
metric
[
"unit"
],
None
)
self
.
assertEqual
(
metric
[
"global_step"
],
1e4
)
self
.
assertEqual
(
metric
[
"extras"
],
[{
"name"
:
"name"
,
"value"
:
"value"
}])
def
test_log_multiple_metrics
(
self
):
log_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
log
=
logger
.
BenchmarkFileLogger
(
log_dir
)
log
.
log_metric
(
"accuracy"
,
0.999
,
global_step
=
1e4
,
extras
=
{
"name"
:
"value"
})
log
.
log_metric
(
"loss"
,
0.02
,
global_step
=
1e4
)
metric_log
=
os
.
path
.
join
(
log_dir
,
"metric.log"
)
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
metric_log
))
with
tf
.
io
.
gfile
.
GFile
(
metric_log
)
as
f
:
accuracy
=
json
.
loads
(
f
.
readline
())
self
.
assertEqual
(
accuracy
[
"name"
],
"accuracy"
)
self
.
assertEqual
(
accuracy
[
"value"
],
0.999
)
self
.
assertEqual
(
accuracy
[
"unit"
],
None
)
self
.
assertEqual
(
accuracy
[
"global_step"
],
1e4
)
self
.
assertEqual
(
accuracy
[
"extras"
],
[{
"name"
:
"name"
,
"value"
:
"value"
}])
loss
=
json
.
loads
(
f
.
readline
())
self
.
assertEqual
(
loss
[
"name"
],
"loss"
)
self
.
assertEqual
(
loss
[
"value"
],
0.02
)
self
.
assertEqual
(
loss
[
"unit"
],
None
)
self
.
assertEqual
(
loss
[
"global_step"
],
1e4
)
self
.
assertEqual
(
loss
[
"extras"
],
[])
def
test_log_non_number_value
(
self
):
log_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
log
=
logger
.
BenchmarkFileLogger
(
log_dir
)
const
=
tf
.
constant
(
1
)
log
.
log_metric
(
"accuracy"
,
const
)
metric_log
=
os
.
path
.
join
(
log_dir
,
"metric.log"
)
self
.
assertFalse
(
tf
.
io
.
gfile
.
exists
(
metric_log
))
def
test_log_evaluation_result
(
self
):
eval_result
=
{
"loss"
:
0.46237424
,
"global_step"
:
207082
,
"accuracy"
:
0.9285
}
log_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
log
=
logger
.
BenchmarkFileLogger
(
log_dir
)
log
.
log_evaluation_result
(
eval_result
)
metric_log
=
os
.
path
.
join
(
log_dir
,
"metric.log"
)
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
metric_log
))
with
tf
.
io
.
gfile
.
GFile
(
metric_log
)
as
f
:
accuracy
=
json
.
loads
(
f
.
readline
())
self
.
assertEqual
(
accuracy
[
"name"
],
"accuracy"
)
self
.
assertEqual
(
accuracy
[
"value"
],
0.9285
)
self
.
assertEqual
(
accuracy
[
"unit"
],
None
)
self
.
assertEqual
(
accuracy
[
"global_step"
],
207082
)
loss
=
json
.
loads
(
f
.
readline
())
self
.
assertEqual
(
loss
[
"name"
],
"loss"
)
self
.
assertEqual
(
loss
[
"value"
],
0.46237424
)
self
.
assertEqual
(
loss
[
"unit"
],
None
)
self
.
assertEqual
(
loss
[
"global_step"
],
207082
)
def
test_log_evaluation_result_with_invalid_type
(
self
):
eval_result
=
"{'loss': 0.46237424, 'global_step': 207082}"
log_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
log
=
logger
.
BenchmarkFileLogger
(
log_dir
)
log
.
log_evaluation_result
(
eval_result
)
metric_log
=
os
.
path
.
join
(
log_dir
,
"metric.log"
)
self
.
assertFalse
(
tf
.
io
.
gfile
.
exists
(
metric_log
))
def
test_collect_tensorflow_info
(
self
):
run_info
=
{}
logger
.
_collect_tensorflow_info
(
run_info
)
self
.
assertNotEqual
(
run_info
[
"tensorflow_version"
],
{})
self
.
assertEqual
(
run_info
[
"tensorflow_version"
][
"version"
],
tf
.
version
.
VERSION
)
self
.
assertEqual
(
run_info
[
"tensorflow_version"
][
"git_hash"
],
tf
.
version
.
GIT_VERSION
)
def
test_collect_run_params
(
self
):
run_info
=
{}
run_parameters
=
{
"batch_size"
:
32
,
"synthetic_data"
:
True
,
"train_epochs"
:
100.00
,
"dtype"
:
"fp16"
,
"resnet_size"
:
50
,
"random_tensor"
:
tf
.
constant
(
2.0
)
}
logger
.
_collect_run_params
(
run_info
,
run_parameters
)
self
.
assertEqual
(
len
(
run_info
[
"run_parameters"
]),
6
)
self
.
assertEqual
(
run_info
[
"run_parameters"
][
0
],
{
"name"
:
"batch_size"
,
"long_value"
:
32
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
1
],
{
"name"
:
"dtype"
,
"string_value"
:
"fp16"
})
v1_tensor
=
{
"name"
:
"random_tensor"
,
"string_value"
:
"Tensor(
\"
Const:0
\"
, shape=(), dtype=float32)"
}
v2_tensor
=
{
"name"
:
"random_tensor"
,
"string_value"
:
"tf.Tensor(2.0, shape=(), dtype=float32)"
}
self
.
assertIn
(
run_info
[
"run_parameters"
][
2
],
[
v1_tensor
,
v2_tensor
])
self
.
assertEqual
(
run_info
[
"run_parameters"
][
3
],
{
"name"
:
"resnet_size"
,
"long_value"
:
50
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
4
],
{
"name"
:
"synthetic_data"
,
"bool_value"
:
"True"
})
self
.
assertEqual
(
run_info
[
"run_parameters"
][
5
],
{
"name"
:
"train_epochs"
,
"float_value"
:
100.00
})
def
test_collect_tensorflow_environment_variables
(
self
):
os
.
environ
[
"TF_ENABLE_WINOGRAD_NONFUSED"
]
=
"1"
os
.
environ
[
"TF_OTHER"
]
=
"2"
os
.
environ
[
"OTHER"
]
=
"3"
run_info
=
{}
logger
.
_collect_tensorflow_environment_variables
(
run_info
)
self
.
assertIsNotNone
(
run_info
[
"tensorflow_environment_variables"
])
expected_tf_envs
=
[
{
"name"
:
"TF_ENABLE_WINOGRAD_NONFUSED"
,
"value"
:
"1"
},
{
"name"
:
"TF_OTHER"
,
"value"
:
"2"
},
]
self
.
assertEqual
(
run_info
[
"tensorflow_environment_variables"
],
expected_tf_envs
)
def
test_collect_memory_info
(
self
):
run_info
=
{
"machine_config"
:
{}}
logger
.
_collect_memory_info
(
run_info
)
self
.
assertIsNotNone
(
run_info
[
"machine_config"
][
"memory_total"
])
self
.
assertIsNotNone
(
run_info
[
"machine_config"
][
"memory_available"
])
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/r1/utils/logs/metric_hook.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Session hook for logging benchmark metric."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
class
LoggingMetricHook
(
tf
.
estimator
.
LoggingTensorHook
):
"""Hook to log benchmark metric information.
This hook is very similar as tf.train.LoggingTensorHook, which logs given
tensors every N local steps, every N seconds, or at the end. The metric
information will be logged to given log_dir or via metric_logger in JSON
format, which can be consumed by data analysis pipeline later.
Note that if `at_end` is True, `tensors` should not include any tensor
whose evaluation produces a side effect such as consuming additional inputs.
"""
def
__init__
(
self
,
tensors
,
metric_logger
=
None
,
every_n_iter
=
None
,
every_n_secs
=
None
,
at_end
=
False
):
"""Initializer for LoggingMetricHook.
Args:
tensors: `dict` that maps string-valued tags to tensors/tensor names,
or `iterable` of tensors/tensor names.
metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
hook should use to write the log.
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
every_n_secs: `int` or `float`, print the values of `tensors` once every N
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
provided.
at_end: `bool` specifying whether to print the values of `tensors` at the
end of the run.
Raises:
ValueError:
1. `every_n_iter` is non-positive, or
2. Exactly one of every_n_iter and every_n_secs should be provided.
3. Exactly one of log_dir and metric_logger should be provided.
"""
super
(
LoggingMetricHook
,
self
).
__init__
(
tensors
=
tensors
,
every_n_iter
=
every_n_iter
,
every_n_secs
=
every_n_secs
,
at_end
=
at_end
)
if
metric_logger
is
None
:
raise
ValueError
(
"metric_logger should be provided."
)
self
.
_logger
=
metric_logger
def
begin
(
self
):
super
(
LoggingMetricHook
,
self
).
begin
()
self
.
_global_step_tensor
=
tf
.
compat
.
v1
.
train
.
get_global_step
()
if
self
.
_global_step_tensor
is
None
:
raise
RuntimeError
(
"Global step should be created to use LoggingMetricHook."
)
if
self
.
_global_step_tensor
.
name
not
in
self
.
_current_tensors
:
self
.
_current_tensors
[
self
.
_global_step_tensor
.
name
]
=
(
self
.
_global_step_tensor
)
def
after_run
(
self
,
unused_run_context
,
run_values
):
# should_trigger is a internal state that populated at before_run, and it is
# using self_timer to determine whether it should trigger.
if
self
.
_should_trigger
:
self
.
_log_metric
(
run_values
.
results
)
self
.
_iter_count
+=
1
def
end
(
self
,
session
):
if
self
.
_log_at_end
:
values
=
session
.
run
(
self
.
_current_tensors
)
self
.
_log_metric
(
values
)
def
_log_metric
(
self
,
tensor_values
):
self
.
_timer
.
update_last_triggered_step
(
self
.
_iter_count
)
global_step
=
tensor_values
[
self
.
_global_step_tensor
.
name
]
# self._tag_order is populated during the init of LoggingTensorHook
for
tag
in
self
.
_tag_order
:
self
.
_logger
.
log_metric
(
tag
,
tensor_values
[
tag
],
global_step
=
global_step
)
official/r1/utils/logs/metric_hook_test.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metric_hook."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tempfile
import
time
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
tensorflow.python.training
import
monitored_session
# pylint: disable=g-bad-import-order
from
official.r1.utils.logs
import
metric_hook
from
official.r1.utils.logs
import
mock_lib
class
LoggingMetricHookTest
(
tf
.
test
.
TestCase
):
"""Tests for LoggingMetricHook."""
def
setUp
(
self
):
super
(
LoggingMetricHookTest
,
self
).
setUp
()
self
.
_log_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
self
.
_logger
=
mock_lib
.
MockBenchmarkLogger
()
def
tearDown
(
self
):
super
(
LoggingMetricHookTest
,
self
).
tearDown
()
tf
.
io
.
gfile
.
rmtree
(
self
.
get_temp_dir
())
def
test_illegal_args
(
self
):
with
self
.
assertRaisesRegexp
(
ValueError
,
"nvalid every_n_iter"
):
metric_hook
.
LoggingMetricHook
(
tensors
=
[
"t"
],
every_n_iter
=
0
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"nvalid every_n_iter"
):
metric_hook
.
LoggingMetricHook
(
tensors
=
[
"t"
],
every_n_iter
=-
10
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"xactly one of"
):
metric_hook
.
LoggingMetricHook
(
tensors
=
[
"t"
],
every_n_iter
=
5
,
every_n_secs
=
5
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"xactly one of"
):
metric_hook
.
LoggingMetricHook
(
tensors
=
[
"t"
])
with
self
.
assertRaisesRegexp
(
ValueError
,
"metric_logger"
):
metric_hook
.
LoggingMetricHook
(
tensors
=
[
"t"
],
every_n_iter
=
5
)
def
test_print_at_end_only
(
self
):
with
tf
.
Graph
().
as_default
(),
tf
.
compat
.
v1
.
Session
()
as
sess
:
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
t
=
tf
.
constant
(
42.0
,
name
=
"foo"
)
train_op
=
tf
.
constant
(
3
)
hook
=
metric_hook
.
LoggingMetricHook
(
tensors
=
[
t
.
name
],
at_end
=
True
,
metric_logger
=
self
.
_logger
)
hook
.
begin
()
mon_sess
=
monitored_session
.
_HookedSession
(
sess
,
[
hook
])
# pylint: disable=protected-access
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
for
_
in
range
(
3
):
mon_sess
.
run
(
train_op
)
self
.
assertEqual
(
self
.
_logger
.
logged_metric
,
[])
hook
.
end
(
sess
)
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
1
)
metric
=
self
.
_logger
.
logged_metric
[
0
]
self
.
assertRegexpMatches
(
metric
[
"name"
],
"foo"
)
self
.
assertEqual
(
metric
[
"value"
],
42.0
)
self
.
assertEqual
(
metric
[
"unit"
],
None
)
self
.
assertEqual
(
metric
[
"global_step"
],
0
)
def
test_global_step_not_found
(
self
):
with
tf
.
Graph
().
as_default
():
t
=
tf
.
constant
(
42.0
,
name
=
"foo"
)
hook
=
metric_hook
.
LoggingMetricHook
(
tensors
=
[
t
.
name
],
at_end
=
True
,
metric_logger
=
self
.
_logger
)
with
self
.
assertRaisesRegexp
(
RuntimeError
,
"should be created to use LoggingMetricHook."
):
hook
.
begin
()
def
test_log_tensors
(
self
):
with
tf
.
Graph
().
as_default
(),
tf
.
compat
.
v1
.
Session
()
as
sess
:
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
t1
=
tf
.
constant
(
42.0
,
name
=
"foo"
)
t2
=
tf
.
constant
(
43.0
,
name
=
"bar"
)
train_op
=
tf
.
constant
(
3
)
hook
=
metric_hook
.
LoggingMetricHook
(
tensors
=
[
t1
,
t2
],
at_end
=
True
,
metric_logger
=
self
.
_logger
)
hook
.
begin
()
mon_sess
=
monitored_session
.
_HookedSession
(
sess
,
[
hook
])
# pylint: disable=protected-access
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
for
_
in
range
(
3
):
mon_sess
.
run
(
train_op
)
self
.
assertEqual
(
self
.
_logger
.
logged_metric
,
[])
hook
.
end
(
sess
)
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
2
)
metric1
=
self
.
_logger
.
logged_metric
[
0
]
self
.
assertRegexpMatches
(
str
(
metric1
[
"name"
]),
"foo"
)
self
.
assertEqual
(
metric1
[
"value"
],
42.0
)
self
.
assertEqual
(
metric1
[
"unit"
],
None
)
self
.
assertEqual
(
metric1
[
"global_step"
],
0
)
metric2
=
self
.
_logger
.
logged_metric
[
1
]
self
.
assertRegexpMatches
(
str
(
metric2
[
"name"
]),
"bar"
)
self
.
assertEqual
(
metric2
[
"value"
],
43.0
)
self
.
assertEqual
(
metric2
[
"unit"
],
None
)
self
.
assertEqual
(
metric2
[
"global_step"
],
0
)
def
_validate_print_every_n_steps
(
self
,
sess
,
at_end
):
t
=
tf
.
constant
(
42.0
,
name
=
"foo"
)
train_op
=
tf
.
constant
(
3
)
hook
=
metric_hook
.
LoggingMetricHook
(
tensors
=
[
t
.
name
],
every_n_iter
=
10
,
at_end
=
at_end
,
metric_logger
=
self
.
_logger
)
hook
.
begin
()
mon_sess
=
monitored_session
.
_HookedSession
(
sess
,
[
hook
])
# pylint: disable=protected-access
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
mon_sess
.
run
(
train_op
)
self
.
assertRegexpMatches
(
str
(
self
.
_logger
.
logged_metric
),
t
.
name
)
for
_
in
range
(
3
):
self
.
_logger
.
logged_metric
=
[]
for
_
in
range
(
9
):
mon_sess
.
run
(
train_op
)
# assertNotRegexpMatches is not supported by python 3.1 and later
self
.
assertEqual
(
str
(
self
.
_logger
.
logged_metric
).
find
(
t
.
name
),
-
1
)
mon_sess
.
run
(
train_op
)
self
.
assertRegexpMatches
(
str
(
self
.
_logger
.
logged_metric
),
t
.
name
)
# Add additional run to verify proper reset when called multiple times.
self
.
_logger
.
logged_metric
=
[]
mon_sess
.
run
(
train_op
)
# assertNotRegexpMatches is not supported by python 3.1 and later
self
.
assertEqual
(
str
(
self
.
_logger
.
logged_metric
).
find
(
t
.
name
),
-
1
)
self
.
_logger
.
logged_metric
=
[]
hook
.
end
(
sess
)
if
at_end
:
self
.
assertRegexpMatches
(
str
(
self
.
_logger
.
logged_metric
),
t
.
name
)
else
:
# assertNotRegexpMatches is not supported by python 3.1 and later
self
.
assertEqual
(
str
(
self
.
_logger
.
logged_metric
).
find
(
t
.
name
),
-
1
)
def
test_print_every_n_steps
(
self
):
with
tf
.
Graph
().
as_default
(),
tf
.
compat
.
v1
.
Session
()
as
sess
:
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
self
.
_validate_print_every_n_steps
(
sess
,
at_end
=
False
)
# Verify proper reset.
self
.
_validate_print_every_n_steps
(
sess
,
at_end
=
False
)
def
test_print_every_n_steps_and_end
(
self
):
with
tf
.
Graph
().
as_default
(),
tf
.
compat
.
v1
.
Session
()
as
sess
:
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
self
.
_validate_print_every_n_steps
(
sess
,
at_end
=
True
)
# Verify proper reset.
self
.
_validate_print_every_n_steps
(
sess
,
at_end
=
True
)
def
_validate_print_every_n_secs
(
self
,
sess
,
at_end
):
t
=
tf
.
constant
(
42.0
,
name
=
"foo"
)
train_op
=
tf
.
constant
(
3
)
hook
=
metric_hook
.
LoggingMetricHook
(
tensors
=
[
t
.
name
],
every_n_secs
=
1.0
,
at_end
=
at_end
,
metric_logger
=
self
.
_logger
)
hook
.
begin
()
mon_sess
=
monitored_session
.
_HookedSession
(
sess
,
[
hook
])
# pylint: disable=protected-access
sess
.
run
(
tf
.
compat
.
v1
.
global_variables_initializer
())
mon_sess
.
run
(
train_op
)
self
.
assertRegexpMatches
(
str
(
self
.
_logger
.
logged_metric
),
t
.
name
)
# assertNotRegexpMatches is not supported by python 3.1 and later
self
.
_logger
.
logged_metric
=
[]
mon_sess
.
run
(
train_op
)
self
.
assertEqual
(
str
(
self
.
_logger
.
logged_metric
).
find
(
t
.
name
),
-
1
)
time
.
sleep
(
1.0
)
self
.
_logger
.
logged_metric
=
[]
mon_sess
.
run
(
train_op
)
self
.
assertRegexpMatches
(
str
(
self
.
_logger
.
logged_metric
),
t
.
name
)
self
.
_logger
.
logged_metric
=
[]
hook
.
end
(
sess
)
if
at_end
:
self
.
assertRegexpMatches
(
str
(
self
.
_logger
.
logged_metric
),
t
.
name
)
else
:
# assertNotRegexpMatches is not supported by python 3.1 and later
self
.
assertEqual
(
str
(
self
.
_logger
.
logged_metric
).
find
(
t
.
name
),
-
1
)
def
test_print_every_n_secs
(
self
):
with
tf
.
Graph
().
as_default
(),
tf
.
compat
.
v1
.
Session
()
as
sess
:
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
self
.
_validate_print_every_n_secs
(
sess
,
at_end
=
False
)
# Verify proper reset.
self
.
_validate_print_every_n_secs
(
sess
,
at_end
=
False
)
def
test_print_every_n_secs_and_end
(
self
):
with
tf
.
Graph
().
as_default
(),
tf
.
compat
.
v1
.
Session
()
as
sess
:
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
()
self
.
_validate_print_every_n_secs
(
sess
,
at_end
=
True
)
# Verify proper reset.
self
.
_validate_print_every_n_secs
(
sess
,
at_end
=
True
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/r1/utils/logs/mlperf_helper.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for the mlperf logging utils.
MLPerf compliance logging is only desired under a limited set of circumstances.
This module is intended to keep users from needing to consider logging (or
install the module) unless they are performing mlperf runs.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
namedtuple
import
json
import
os
import
re
import
subprocess
import
sys
from
absl
import
logging
import
typing
# pylint:disable=logging-format-interpolation
_MIN_VERSION
=
(
0
,
0
,
10
)
_STACK_OFFSET
=
2
SUDO
=
"sudo"
if
os
.
geteuid
()
else
""
# This indirection is used in docker.
DROP_CACHE_LOC
=
os
.
getenv
(
"DROP_CACHE_LOC"
,
"/proc/sys/vm/drop_caches"
)
_NCF_PREFIX
=
"NCF_RAW_"
# TODO(robieta): move line parsing to mlperf util
_PREFIX
=
r
"(?:{})?:::MLPv([0-9]+).([0-9]+).([0-9]+)"
.
format
(
_NCF_PREFIX
)
_BENCHMARK
=
r
"([a-zA-Z0-9_]+)"
_TIMESTAMP
=
r
"([0-9]+\.[0-9]+)"
_CALLSITE
=
r
"\((.+):([0-9]+)\)"
_TAG
=
r
"([a-zA-Z0-9_]+)"
_VALUE
=
r
"(.*)"
ParsedLine
=
namedtuple
(
"ParsedLine"
,
[
"version"
,
"benchmark"
,
"timestamp"
,
"callsite"
,
"tag"
,
"value"
])
LINE_PATTERN
=
re
.
compile
(
"^{prefix} {benchmark} {timestamp} {callsite} {tag}(: |$){value}?$"
.
format
(
prefix
=
_PREFIX
,
benchmark
=
_BENCHMARK
,
timestamp
=
_TIMESTAMP
,
callsite
=
_CALLSITE
,
tag
=
_TAG
,
value
=
_VALUE
))
def
parse_line
(
line
):
# type: (str) -> typing.Optional[ParsedLine]
match
=
LINE_PATTERN
.
match
(
line
.
strip
())
if
not
match
:
return
major
,
minor
,
micro
,
benchmark
,
timestamp
=
match
.
groups
()[:
5
]
call_file
,
call_line
,
tag
,
_
,
value
=
match
.
groups
()[
5
:]
return
ParsedLine
(
version
=
(
int
(
major
),
int
(
minor
),
int
(
micro
)),
benchmark
=
benchmark
,
timestamp
=
timestamp
,
callsite
=
(
call_file
,
call_line
),
tag
=
tag
,
value
=
value
)
def
unparse_line
(
parsed_line
):
# type: (ParsedLine) -> str
version_str
=
"{}.{}.{}"
.
format
(
*
parsed_line
.
version
)
callsite_str
=
"({}:{})"
.
format
(
*
parsed_line
.
callsite
)
value_str
=
": {}"
.
format
(
parsed_line
.
value
)
if
parsed_line
.
value
else
""
return
":::MLPv{} {} {} {} {} {}"
.
format
(
version_str
,
parsed_line
.
benchmark
,
parsed_line
.
timestamp
,
callsite_str
,
parsed_line
.
tag
,
value_str
)
def
get_mlperf_log
():
"""Shielded import of mlperf_log module."""
try
:
import
mlperf_compliance
def
test_mlperf_log_pip_version
():
"""Check that mlperf_compliance is up to date."""
import
pkg_resources
version
=
pkg_resources
.
get_distribution
(
"mlperf_compliance"
)
version
=
tuple
(
int
(
i
)
for
i
in
version
.
version
.
split
(
"."
))
if
version
<
_MIN_VERSION
:
logging
.
warning
(
"mlperf_compliance is version {}, must be >= {}"
.
format
(
"."
.
join
([
str
(
i
)
for
i
in
version
]),
"."
.
join
([
str
(
i
)
for
i
in
_MIN_VERSION
])))
raise
ImportError
return
mlperf_compliance
.
mlperf_log
mlperf_log
=
test_mlperf_log_pip_version
()
except
ImportError
:
mlperf_log
=
None
return
mlperf_log
class
Logger
(
object
):
"""MLPerf logger indirection class.
This logger only logs for MLPerf runs, and prevents various errors associated
with not having the mlperf_compliance package installed.
"""
class
Tags
(
object
):
def
__init__
(
self
,
mlperf_log
):
self
.
_enabled
=
False
self
.
_mlperf_log
=
mlperf_log
def
__getattr__
(
self
,
item
):
if
self
.
_mlperf_log
is
None
or
not
self
.
_enabled
:
return
return
getattr
(
self
.
_mlperf_log
,
item
)
def
__init__
(
self
):
self
.
_enabled
=
False
self
.
_mlperf_log
=
get_mlperf_log
()
self
.
tags
=
self
.
Tags
(
self
.
_mlperf_log
)
def
__call__
(
self
,
enable
=
False
):
if
enable
and
self
.
_mlperf_log
is
None
:
raise
ImportError
(
"MLPerf logging was requested, but mlperf_compliance "
"module could not be loaded."
)
self
.
_enabled
=
enable
self
.
tags
.
_enabled
=
enable
return
self
def
__enter__
(
self
):
pass
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
_enabled
=
False
self
.
tags
.
_enabled
=
False
@
property
def
log_file
(
self
):
if
self
.
_mlperf_log
is
None
:
return
return
self
.
_mlperf_log
.
LOG_FILE
@
property
def
enabled
(
self
):
return
self
.
_enabled
def
ncf_print
(
self
,
key
,
value
=
None
,
stack_offset
=
_STACK_OFFSET
,
deferred
=
False
,
extra_print
=
False
,
prefix
=
_NCF_PREFIX
):
if
self
.
_mlperf_log
is
None
or
not
self
.
enabled
:
return
self
.
_mlperf_log
.
ncf_print
(
key
=
key
,
value
=
value
,
stack_offset
=
stack_offset
,
deferred
=
deferred
,
extra_print
=
extra_print
,
prefix
=
prefix
)
def
set_ncf_root
(
self
,
path
):
if
self
.
_mlperf_log
is
None
:
return
self
.
_mlperf_log
.
ROOT_DIR_NCF
=
path
LOGGER
=
Logger
()
ncf_print
,
set_ncf_root
=
LOGGER
.
ncf_print
,
LOGGER
.
set_ncf_root
TAGS
=
LOGGER
.
tags
def
clear_system_caches
():
if
not
LOGGER
.
enabled
:
return
ret_code
=
subprocess
.
call
(
[
"sync && echo 3 | {} tee {}"
.
format
(
SUDO
,
DROP_CACHE_LOC
)],
shell
=
True
)
if
ret_code
:
raise
ValueError
(
"Failed to clear caches"
)
if
__name__
==
"__main__"
:
logging
.
set_verbosity
(
logging
.
INFO
)
with
LOGGER
(
True
):
ncf_print
(
key
=
TAGS
.
RUN_START
)
official/r1/utils/logs/mock_lib.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mock objects and related functions for testing."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
class
MockBenchmarkLogger
(
object
):
"""This is a mock logger that can be used in dependent tests."""
def
__init__
(
self
):
self
.
logged_metric
=
[]
def
log_metric
(
self
,
name
,
value
,
unit
=
None
,
global_step
=
None
,
extras
=
None
):
self
.
logged_metric
.
append
({
"name"
:
name
,
"value"
:
float
(
value
),
"unit"
:
unit
,
"global_step"
:
global_step
,
"extras"
:
extras
})
official/r1/utils/tpu.py
deleted
100644 → 0
View file @
d4f5c193
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions specific to running TensorFlow on TPUs."""
import
tensorflow
as
tf
# "local" is a magic word in the TPU cluster resolver; it informs the resolver
# to use the local CPU as the compute device. This is useful for testing and
# debugging; the code flow is ostensibly identical, but without the need to
# actually have a TPU on the other end.
LOCAL
=
"local"
def
construct_scalar_host_call
(
metric_dict
,
model_dir
,
prefix
=
""
):
"""Construct a host call to log scalars when training on TPU.
Args:
metric_dict: A dict of the tensors to be logged.
model_dir: The location to write the summary.
prefix: The prefix (if any) to prepend to the metric names.
Returns:
A tuple of (function, args_to_be_passed_to_said_function)
"""
# type: (dict, str) -> (function, list)
metric_names
=
list
(
metric_dict
.
keys
())
def
host_call_fn
(
global_step
,
*
args
):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to the `metric_fn`, provide as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
global_step: `Tensor with shape `[batch]` for the global_step
*args: Remaining tensors to log.
Returns:
List of summary ops to run on the CPU host.
"""
step
=
global_step
[
0
]
with
tf
.
compat
.
v1
.
summary
.
create_file_writer
(
logdir
=
model_dir
,
filename_suffix
=
".host_call"
).
as_default
():
with
tf
.
compat
.
v1
.
summary
.
always_record_summaries
():
for
i
,
name
in
enumerate
(
metric_names
):
tf
.
compat
.
v1
.
summary
.
scalar
(
prefix
+
name
,
args
[
i
][
0
],
step
=
step
)
return
tf
.
compat
.
v1
.
summary
.
all_summary_ops
()
# To log the current learning rate, and gradient norm for Tensorboard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_tensor
=
tf
.
reshape
(
tf
.
compat
.
v1
.
train
.
get_or_create_global_step
(),
[
1
])
other_tensors
=
[
tf
.
reshape
(
metric_dict
[
key
],
[
1
])
for
key
in
metric_names
]
return
host_call_fn
,
[
global_step_tensor
]
+
other_tensors
def
embedding_matmul
(
embedding_table
,
values
,
mask
,
name
=
"embedding_matmul"
):
"""Performs embedding lookup via a matmul.
The matrix to be multiplied by the embedding table Tensor is constructed
via an implementation of scatter based on broadcasting embedding indices
and performing an equality comparison against a broadcasted
range(num_embedding_table_rows). All masked positions will produce an
embedding vector of zeros.
Args:
embedding_table: Tensor of embedding table.
Rank 2 (table_size x embedding dim)
values: Tensor of embedding indices. Rank 2 (batch x n_indices)
mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
name: Optional name scope for created ops
Returns:
Rank 3 tensor of embedding vectors.
"""
with
tf
.
name_scope
(
name
):
n_embeddings
=
embedding_table
.
get_shape
().
as_list
()[
0
]
batch_size
,
padded_size
=
values
.
shape
.
as_list
()
emb_idcs
=
tf
.
tile
(
tf
.
reshape
(
values
,
(
batch_size
,
padded_size
,
1
)),
(
1
,
1
,
n_embeddings
))
emb_weights
=
tf
.
tile
(
tf
.
reshape
(
mask
,
(
batch_size
,
padded_size
,
1
)),
(
1
,
1
,
n_embeddings
))
col_idcs
=
tf
.
tile
(
tf
.
reshape
(
tf
.
range
(
n_embeddings
),
(
1
,
1
,
n_embeddings
)),
(
batch_size
,
padded_size
,
1
))
one_hot
=
tf
.
where
(
tf
.
equal
(
emb_idcs
,
col_idcs
),
emb_weights
,
tf
.
zeros
((
batch_size
,
padded_size
,
n_embeddings
)))
return
tf
.
tensordot
(
one_hot
,
embedding_table
,
1
)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment