Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1f2cebfa
Commit
1f2cebfa
authored
Aug 16, 2019
by
Priya Gupta
Committed by
A. Unique TensorFlower
Aug 16, 2019
Browse files
Fix the monkey patch for synthetic data in the ResNet Keras model.
PiperOrigin-RevId: 263854996
parent
5a309240
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
13 deletions
+47
-13
official/utils/misc/distribution_utils.py
official/utils/misc/distribution_utils.py
+47
-13
No files found.
official/utils/misc/distribution_utils.py
View file @
1f2cebfa
...
class SyntheticDataset(object):
  """Serves one cached batch of a real dataset as endless synthetic input.

  A single element is pulled from the wrapped dataset, stored in local
  variables, and then replayed on every step via `SyntheticIterator`.
  This removes the input pipeline from the critical path (e.g. for
  benchmarking).
  """

  def __init__(self, dataset, split_by=1):
    """Caches one element of `dataset`, keeping 1/`split_by` of the batch.

    Args:
      dataset: A `tf.data.Dataset` that yields at least one element.
      split_by: The cached batch is split along axis 0 into `split_by`
        equal pieces and only the first piece is kept — presumably the
        per-replica share of a global batch (TODO confirm with callers).
    """
    # dataset.take(1) doesn't have a GPU kernel, so pin it to the CPU.
    with tf.device('device:CPU:0'):
      tensor = tf.data.experimental.get_single_element(dataset.take(1))
    flat_tensor = tf.nest.flatten(tensor)
    variable_data = []
    initializers = []
    for t in flat_tensor:
      rebatched_t = tf.split(t, num_or_size_splits=split_by, axis=0)[0]
      # Variables require static shapes; fail loudly otherwise.
      assert rebatched_t.shape.is_fully_defined(), rebatched_t.shape
      v = tf.compat.v1.get_local_variable(self._random_name(),
                                          initializer=rebatched_t)
      variable_data.append(v)
      initializers.append(v.initializer)
    input_data = tf.nest.pack_sequence_as(tensor, variable_data)
    self._iterator = SyntheticIterator(input_data, initializers)

  def _random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
    """Returns a random variable name of `size` characters drawn from `chars`."""
    return ''.join(random.choice(chars) for _ in range(size))

  def __iter__(self):
    return self._iterator

  def make_one_shot_iterator(self):
    # TF1-style iterator API: the same cached iterator is returned every call.
    return self._iterator

  def make_initializable_iterator(self):
    # TF1-style iterator API: the same cached iterator is returned every call.
    return self._iterator
class SyntheticIterator(object):
  """An iterator that re-serves one cached batch of synthetic data forever.

  NOTE(fix): the original docstring was copy-pasted from `SyntheticDataset`
  ("A dataset that generates synthetic data on each device"), which is wrong
  for an iterator class.
  """

  def __init__(self, input_data, initializers):
    """Args:
      input_data: Nested structure of tensors/variables served on each step.
      initializers: Variable-initializer op(s) to run before first use.
    """
    self._input_data = input_data
    self._initializers = initializers

  def get_next(self):
    # Always the same cached structure; this iterator never exhausts.
    return self._input_data

  def next(self):
    # Python 2 compatibility alias for __next__.
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except tf.errors.OutOfRangeError:
      raise StopIteration

  def initialize(self):
    """Returns the initializer: a no-op in eager mode, else the variable initializers."""
    if tf.executing_eagerly():
      return tf.no_op()
    else:
      return self._initializers
def random_name(self, size=10, chars=string.ascii_uppercase + string.digits):
  """Returns a random name of `size` characters drawn from `chars`.

  Args:
    size: Length of the generated name.
    chars: Alphabet to draw from (defaults to uppercase letters + digits).

  Returns:
    A string of `size` independently chosen characters.
  """
  # NOTE(review): uses `random`, not `secrets` — fine for variable names,
  # not for anything security-sensitive.
  return ''.join(random.choice(chars) for _ in range(size))
def
_monkey_patch_dataset_method
(
strategy
):
def
_monkey_patch_dataset_method
(
strategy
):
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
"""Monkey-patch `strategy`'s `make_dataset_iterator` method."""
def
make_dataset
_iterator
(
self
,
dataset
):
def
make_dataset
(
self
,
dataset
):
tf
.
compat
.
v1
.
logging
.
info
(
'Using pure synthetic data.'
)
tf
.
compat
.
v1
.
logging
.
info
(
'Using pure synthetic data.'
)
with
self
.
scope
():
with
self
.
scope
():
if
self
.
extended
.
_global_batch_size
:
# pylint: disable=protected-access
if
self
.
extended
.
_global_batch_size
:
# pylint: disable=protected-access
...
@@ -244,13 +270,21 @@ def _monkey_patch_dataset_method(strategy):
...
@@ -244,13 +270,21 @@ def _monkey_patch_dataset_method(strategy):
else
:
else
:
return
SyntheticDataset
(
dataset
)
return
SyntheticDataset
(
dataset
)
strategy
.
org_make_dataset_iterator
=
strategy
.
make_dataset_iterator
def
make_iterator
(
self
,
dataset
):
strategy
.
make_dataset_iterator
=
make_dataset_iterator
dist_dataset
=
make_dataset
(
self
,
dataset
)
return
iter
(
dist_dataset
)
strategy
.
orig_make_dataset_iterator
=
strategy
.
make_dataset_iterator
strategy
.
make_dataset_iterator
=
make_iterator
strategy
.
orig_distribute_dataset
=
strategy
.
experimental_distribute_dataset
strategy
.
experimental_distribute_dataset
=
make_dataset
def
_undo_monkey_patch_dataset_method
(
strategy
):
def
_undo_monkey_patch_dataset_method
(
strategy
):
if
hasattr
(
strategy
,
'org_make_dataset_iterator'
):
if
hasattr
(
strategy
,
'orig_make_dataset_iterator'
):
strategy
.
make_dataset_iterator
=
strategy
.
org_make_dataset_iterator
strategy
.
make_dataset_iterator
=
strategy
.
orig_make_dataset_iterator
if
hasattr
(
strategy
,
'orig_distribute_dataset'
):
strategy
.
make_dataset_iterator
=
strategy
.
orig_distribute_dataset
def
set_up_synthetic_data
():
def
set_up_synthetic_data
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment