dcuai / dlexamples · Commits

Commit a32ffa95 · authored Feb 03, 2023 by qianyj
update TensorFlow2x test method
parent e286da17

Changes: 268 · Showing 20 changed files with 2783 additions and 0 deletions (+2783, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/oneof_test.py  (+71, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/params_dict.py  (+464, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/params_dict_test.py  (+429, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/__init__.py  (+14, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/base_model.py  (+45, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/base_trainer.py  (+170, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/base_trainer_test.py  (+90, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/configs.py  (+80, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/evaluator.py  (+180, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/evaluator_test.py  (+133, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/interleaving_trainer.py  (+102, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/interleaving_trainer_test.py  (+102, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/multitask.py  (+145, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/task_sampler.py  (+128, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/task_sampler_test.py  (+75, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/test_utils.py  (+125, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/train_lib.py  (+265, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/train_lib_test.py  (+121, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/optimization/__init__.py  (+24, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/optimization/adafactor_optimizer.py  (+20, -0)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/oneof_test.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import tensorflow as tf

from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import oneof


@dataclasses.dataclass
class ResNet(base_config.Config):
  model_depth: int = 50


@dataclasses.dataclass
class Backbone(oneof.OneOfConfig):
  type: str = 'resnet'
  resnet: ResNet = ResNet()
  not_resnet: int = 2


@dataclasses.dataclass
class OutputLayer(oneof.OneOfConfig):
  type: str = 'single'
  single: int = 1
  multi_head: int = 2


@dataclasses.dataclass
class Network(base_config.Config):
  backbone: Backbone = Backbone()
  output_layer: OutputLayer = OutputLayer()


class OneOfTest(tf.test.TestCase):

  def test_to_dict(self):
    network_params = {
        'backbone': {
            'type': 'resnet',
            'resnet': {'model_depth': 50}
        },
        'output_layer': {
            'type': 'single',
            'single': 1000
        }
    }
    network_config = Network(network_params)
    self.assertEqual(network_config.as_dict(), network_params)

  def test_get_oneof(self):
    backbone = Backbone()
    self.assertIsInstance(backbone.get(), ResNet)
    self.assertEqual(backbone.get().as_dict(), {'model_depth': 50})


if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/params_dict.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A parameter dictionary class which supports the nest structure."""
import collections
import copy
import re

import six
import tensorflow as tf
import yaml

# regex pattern that matches on key-value pairs in a comma-separated
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
  \s*=\s*
  ((?P<val>\'(.*?)\'           # single quote
  |
  \"(.*?)\"                    # double quote
  |
  [^,\[]*                      # single value
  |
  \[[^\]]*\]))                 # list of values
  ($|,\s*)""", re.VERBOSE)

_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')

# Yaml loader with an implicit resolver to parse float decimal and exponential
# format. The regular expression parses the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    re.compile(
        r'''
    ^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$''', re.X),
    list('-+0123456789.'))
class ParamsDict(object):
  """A hyperparameter container class."""

  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions. Upon initialization, it validates itself by checking all the
    defined restrictions, and raises an error if it finds any inconsistency.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=', '<', '<=', '>', '>='}.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    self.override(default_params, is_strict=False)

  def _set(self, k, v):
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
    """
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__.keys():
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__.keys():
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)
  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowed. '.format(k))
    if k not in self.__dict__.keys():
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is
        currently defined in the ParamsDict. In this case, the ParamsDict will
        be extended to include the new keys.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)  # pylint: disable=protected-access

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`."""
    for k, v in six.iteritems(override_dict):
      if k in ParamsDict.RESERVED_ATTR:
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__.keys():
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        else:
          self._set(k, v)
      else:
        if isinstance(v, dict):
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, ParamsDict):
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self.__dict__[k] = copy.deepcopy(v)
  def lock(self):
    """Makes the ParamsDict immutable."""
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned.
    """
    params_dict = {}
    for k, v in six.iteritems(self.__dict__):
      if k not in ParamsDict.RESERVED_ATTR:
        if isinstance(v, ParamsDict):
          params_dict[k] = v.as_dict()
        else:
          params_dict[k] = copy.deepcopy(v)
    return params_dict
  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list
    of restrictions. A restriction is defined as a string which specifies a
    binary operation. The supported binary operations are {'==', '!=', '<',
    '<=', '>', '>='}. Note that the meaning of these operators is consistent
    with the underlying Python implementation. Users should make sure the
    defined restrictions on their types make sense.

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
    - a.a1 = 1 == b.ccc.a1 = 1
    - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_kv(dotted_string, params_dict):
      """Get keys and values indicated by dotted_string."""
      if _CONST_VALUE_RE.match(dotted_string) is not None:
        const_str = dotted_string
        if const_str == 'None':
          constant = None
        else:
          constant = float(const_str)
        return None, constant
      else:
        tokenized_params = dotted_string.split('.')
        v = params_dict
        for t in tokenized_params:
          v = v[t]
        return tokenized_params[-1], v

    def _get_kvs(tokens, params_dict):
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      stripped_tokens = [t.strip() for t in tokens]
      left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
      right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
      return left_k, left_v, right_k, right_v

    params_dict = self.as_dict()
    # Check the two-character operators ('<=', '>=') before their
    # one-character prefixes so that such restrictions are not mis-parsed.
    for restriction in self._restrictions:
      if '==' in restriction:
        tokens = restriction.split('==')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v != right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '!=' in restriction:
        tokens = restriction.split('!=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v == right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<=' in restriction:
        tokens = restriction.split('<=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v > right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<' in restriction:
        tokens = restriction.split('<')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v >= right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>=' in restriction:
        tokens = restriction.split('>=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v < right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>' in restriction:
        tokens = restriction.split('>')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v <= right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      else:
        raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path: str):
  """Reads a YAML file to a ParamsDict."""
  with tf.io.gfile.GFile(file_path, 'r') as f:
    params_dict = yaml.load(f, Loader=LOADER)
    return ParamsDict(params_dict)


def save_params_dict_to_yaml(params, file_path):
  """Saves the input ParamsDict to a YAML file."""
  with tf.io.gfile.GFile(file_path, 'w') as f:

    def _my_list_rep(dumper, data):
      # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
      return dumper.represent_sequence(
          u'tag:yaml.org,2002:seq', data, flow_style=True)

    yaml.add_representer(list, _my_list_rep)
    yaml.dump(params.as_dict(), f, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
  """Converts a nested (using '.') comma-separated k=v string to a JSON string.

  Converts a comma-separated string of key/value pairs that supports
  nesting of keys to a JSON string. Nesting is implemented using
  '.' between levels for a given key.

  Spacing between commas and = is supported (e.g. there is no difference
  between "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no
  spaces before keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not
  supported).

  Note that this will only support values supported by CSV, meaning
  values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
  supported. Strings are supported as well, e.g. "a='hello'".

  An example conversion would be:

  "a=1, b=2, c.a=2, c.b=3, d.a.a=5"

  to

  "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string.

  Raises:
    ValueError: If csv_str is not in a comma separated string or
      if the string is formatted incorrectly.
  """
  if not csv_str:
    return ''

  formatted_entries = []
  nested_map = collections.defaultdict(list)
  pos = 0
  while pos < len(csv_str):
    m = _PARAM_RE.match(csv_str, pos)
    if not m:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[pos:])
    pos = m.end()

    # Parse the values.
    m_dict = m.groupdict()
    name = m_dict['name']
    v = m_dict['val']

    # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
    # as yaml.load would otherwise throw an exception
    if re.match(r'(?=[^\"\'])(?=[gs://])', v):
      v = '\'{}\''.format(v)

    name_nested = name.split('.')
    if len(name_nested) > 1:
      grouping = name_nested[0]
      value = '.'.join(name_nested[1:]) + '=' + v
      nested_map[grouping].append(value)
    else:
      formatted_entries.append('%s : %s' % (name, v))

  for grouping, value in nested_map.items():
    value = ','.join(value)
    value = nested_csv_str_to_json_str(value)
    formatted_entries.append('%s : %s' % (grouping, value))
  return '{' + ', '.join(formatted_entries) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The logic of the function is outlined below:
  1. Test that the input is a dict. If not, proceed to 2.
  2. Test that the input is a string. If not, raise a ValueError.
  2.1. Test if the string is in a CSV format. If so, parse.
       If not, proceed to 2.2.
  2.2. Try loading the string as a YAML/JSON. If successful, parse to
       dict and use it to override. If not, proceed to 2.3.
  2.3. Try using the string as a file path and load the YAML file.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object.

  Raises:
    ValueError: if failed to override the parameters.
  """
  if not dict_or_string_or_yaml_file:
    return params
  if isinstance(dict_or_string_or_yaml_file, dict):
    params.override(dict_or_string_or_yaml_file, is_strict)
  elif isinstance(dict_or_string_or_yaml_file, six.string_types):
    try:
      dict_or_string_or_yaml_file = (
          nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
    except ValueError:
      pass
    params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
    if isinstance(params_dict, dict):
      params.override(params_dict, is_strict)
    else:
      with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
        params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict)
  else:
    raise ValueError('Unknown input type to parse.')
  return params
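Usage note (not part of the diff): the sketch below shows one way the ParamsDict API added above can be driven end to end — construction with a restriction, a strict override from a nested CSV string, and validation. The parameter names (batch_size, eval.batch_size) are illustrative assumptions only.

from official.modeling.hyperparams import params_dict

# Build a ParamsDict whose consistency is guarded by a restriction string.
params = params_dict.ParamsDict(
    {'batch_size': 32, 'eval': {'batch_size': 32}},
    restrictions=['batch_size == eval.batch_size'])

# A nested CSV string is first converted to a JSON/YAML string, then applied
# strictly, so only keys that already exist may be overridden.
params = params_dict.override_params_dict(
    params, 'batch_size=64, eval.batch_size=64', is_strict=True)
params.validate()        # passes: the restriction still holds after override
print(params.as_dict())  # {'batch_size': 64, 'eval': {'batch_size': 64}}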
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/params_dict_test.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for params_dict.py."""
import os

import tensorflow as tf
import yaml

from official.modeling.hyperparams import params_dict


class ParamsDictTest(tf.test.TestCase):

  def test_init_from_an_empty_dict(self):
    params = params_dict.ParamsDict()
    with self.assertRaises(AttributeError):
      _ = params.a

    with self.assertRaises(KeyError):
      params.a = 'aa'

  def test_init_from_a_dict(self):
    params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

  def test_init_from_a_param_dict(self):
    params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
    params = params_dict.ParamsDict(params_init)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

  def test_lock(self):
    params = params_dict.ParamsDict({'a': 1, 'b': 2, 'c': 3})
    params.lock()
    with self.assertRaises(ValueError):
      params.a = 10
    with self.assertRaises(ValueError):
      params.override({'b': 20})
    with self.assertRaises(ValueError):
      del params.c

  def test_setattr(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    params.c = 'ccc'
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, 'ccc')

  def test_getattr(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, None)

  def test_delattr(self):
    params = params_dict.ParamsDict()
    params.override({
        'a': 'aa',
        'b': 2,
        'c': None,
        'd': {'d1': 1, 'd2': 10}
    }, is_strict=False)
    del params.c
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    with self.assertRaises(AttributeError):
      _ = params.c
    del params.d
    with self.assertRaises(AttributeError):
      _ = params.d.d1

  def test_contains(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertIn('a', params)
    self.assertNotIn('b', params)

  def test_get(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertEqual(params.get('a'), 'aa')
    self.assertEqual(params.get('b', 2), 2)
    self.assertEqual(params.get('b'), None)

  def test_override_is_strict_true(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {'c1': 'cc', 'c2': 20}
    })
    params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c1, 'ccc')
    with self.assertRaises(KeyError):
      params.override({'d': 'ddd'}, is_strict=True)
    with self.assertRaises(KeyError):
      params.override({'c': {'c3': 30}}, is_strict=True)

  def test_override_is_strict_false(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {'c1': 10, 'c2': 20}
    })
    params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c3, 3000)
    params.override({'d': 'ddd'}, is_strict=False)
    self.assertEqual(params.d, 'ddd')
    params.override({'c': {'c4': 4444}}, is_strict=False)
    self.assertEqual(params.c.c4, 4444)

  def test_as_dict(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {'c1': 10, 'c2': 20}
    })
    params_d = params.as_dict()
    self.assertEqual(params_d['a'], 'aa')
    self.assertEqual(params_d['b'], 2)
    self.assertEqual(params_d['c']['c1'], 10)
    self.assertEqual(params_d['c']['c2'], 20)

  def test_validate(self):
    # Raise error due to the unknown parameter.
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
      params.validate()

    # OK to check equality of two nested dicts.
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {'a': 10},
        'c': {'a': 10}
    }, ['b == c'])

    # Raise error due to inconsistency.
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
      params.validate()

    # Valid rule.
    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])

    # Overriding violates the existing rule, raise error upon validate.
    params.override({'a': 11})
    with self.assertRaises(KeyError):
      params.validate()

    # Valid restrictions with constant.
    params = params_dict.ParamsDict({
        'a': None,
        'c': {'a': 1}
    }, ['a == None', 'c.a == 1'])
    params.validate()
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({
          'a': 4,
          'c': {'a': 1}
      }, ['a == None', 'c.a == 1'])
      params.validate()
class ParamsDictIOTest(tf.test.TestCase):

  def write_temp_file(self, filename, text):
    temp_file = os.path.join(self.get_temp_dir(), filename)
    with tf.io.gfile.GFile(temp_file, 'w') as writer:
      writer.write(text)
    return temp_file

  def test_save_params_dict_to_yaml(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {'c1': 10, 'c2': 20}
    })
    output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
    params_dict.save_params_dict_to_yaml(params, output_yaml_file)

    with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
      params_d = yaml.load(f)
      self.assertEqual(params.a, params_d['a'])
      self.assertEqual(params.b, params_d['b'])
      self.assertEqual(params.c.c1, params_d['c']['c1'])
      self.assertEqual(params.c.c2, params_d['c']['c2'])

  def test_read_yaml_to_params_dict(self):
    input_yaml_file = self.write_temp_file(
        'params.yaml', r"""
      a: 'aa'
      b: 2
      c:
        c1: 10
        c2: 20
    """)
    params = params_dict.read_yaml_to_params_dict(input_yaml_file)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c.c1, 10)
    self.assertEqual(params.c.c2, 20)

  def test_override_params_dict_using_dict(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_dict = {'b': 5.2, 'c': [30, 40]}
    params = params_dict.override_params_dict(
        params, override_dict, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_yaml_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_string = "'b': 5.2\n'c': [30, 40]"
    params = params_dict.override_params_dict(
        params, override_yaml_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_json_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {'b1': 2, 'b2': [2, 3]},
        'd': {'d1': {'d2': 'hello'}},
        'e': False
    })
    override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
    params = params_dict.override_params_dict(
        params, override_json_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(2, params.b.b1)
    self.assertEqual([3, 4], params.b.b2)
    self.assertEqual('hi', params.d.d1.d2)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_csv_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {'b1': 2, 'b2': [2, 3]},
        'd': {'d1': {'d2': 'hello'}},
        'e': False
    })
    override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(2, params.b.b1)
    self.assertEqual([3, 4], params.b.b2)
    self.assertEqual('hi, world', params.d.d1.d2)
    self.assertEqual('gs://test', params.e)
    # Test different float formats.
    override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)
    self.assertEqual(-1e-3, params.b.b2)
    self.assertEqual(0.001, params.d.d1.d2)
    self.assertEqual(1e3, params.e)
    self.assertEqual(-1.5e-3, params.a)

  def test_override_params_dict_using_yaml_file(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_file = self.write_temp_file(
        'params.yaml', r"""
      b: 5.2
      c: [30, 40]
    """)
    params = params_dict.override_params_dict(
        params, override_yaml_file, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)
class IOTest(tf.test.TestCase):

  def test_basic_csv_str_to_json_str(self):
    csv_str = 'a=1,b=2,c=3'
    json_str = '{a : 1, b : 2, c : 3}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, json_str)

  def test_basic_csv_str_load(self):
    csv_str = 'a=1,b=2,c=3'
    expected_output = {'a': 1, 'b': 2, 'c': 3}
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str)
    self.assertDictEqual(converted_dict, expected_output)

  def test_basic_nested_csv_str_to_json_str(self):
    csv_str = 'a=1,b.b1=2'
    json_str = '{a : 1, b : {b1 : 2}}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, json_str)

  def test_basic_nested_csv_str_load(self):
    csv_str = 'a=1,b.b1=2,c.c1=3'
    expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str)
    self.assertDictEqual(converted_dict, expected_output)

  def test_complex_nested_csv_str_to_json_str(self):
    csv_str = 'a.aa.aaa.aaaaa.a=1'
    json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, json_str)

  def test_complex_nested_csv_str_load(self):
    csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
    expected_output = {
        'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str)
    self.assertDictEqual(converted_dict, expected_output)

  def test_csv_str_load_supported_datatypes(self):
    csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str)
    self.assertEqual(converted_dict['a'], 1)
    self.assertEqual(converted_dict['b'], 2.)
    self.assertEqual(converted_dict['c'], [1, 2, 3])
    self.assertEqual(converted_dict['d'], 'hello, there')
    self.assertEqual(converted_dict['e'], 'Hi.')

  def test_csv_str_load_unsupported_datatypes(self):
    csv_str = 'a=[[1,2,3],[4,5,6]]'
    self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
                      csv_str)

  def test_csv_str_to_json_str_spacing(self):
    csv_str1 = 'a=1,b=2,c=3'
    csv_str2 = 'a = 1, b = 2, c = 3'
    json_str = '{a : 1, b : 2, c : 3}'
    converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
    converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
    self.assertEqual(converted_csv_str1, converted_csv_str2)
    self.assertEqual(converted_csv_str1, json_str)
    self.assertEqual(converted_csv_str2, json_str)

  def test_gcs_added_quotes(self):
    csv_str = 'a=gs://abc, b=gs://def'
    expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, expected_output)


if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/__init__.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/base_model.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstraction of multi-task model."""
from typing import Text, Dict

import tensorflow as tf


class MultiTaskBaseModel(tf.Module):
  """Base class that holds multi-task model computation."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self._sub_tasks = self._instantiate_sub_tasks()

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    """Abstract function that sets up the computation for each sub-task.

    Returns:
      A map from task name (as string) to a tf.keras.Model object that
      represents the sub-task in the multi-task pool.
    """
    raise NotImplementedError(
        "_instantiate_sub_task_models() is not implemented.")

  @property
  def sub_tasks(self):
    """Fetch a map of task name (string) to task model (tf.keras.Model)."""
    return self._sub_tasks

  def initialize(self):
    """Optional function that loads a pre-train checkpoint."""
    return
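Usage note (not part of the diff): MultiTaskBaseModel only requires subclasses to implement _instantiate_sub_tasks(). A minimal, hypothetical subclass could look like the sketch below; the task names and single-layer Keras heads are placeholders, not part of the committed code.

import tensorflow as tf

from official.modeling.multitask import base_model


class TwoHeadModel(base_model.MultiTaskBaseModel):
  """Toy multi-task model: one independent dense head per task."""

  def _instantiate_sub_tasks(self):
    # One tf.keras.Model per task name; shared layers could also be built
    # here once and reused by every head.
    return {
        "foo": tf.keras.Sequential([tf.keras.layers.Dense(2)]),
        "bar": tf.keras.Sequential([tf.keras.layers.Dense(1)]),
    }


model = TwoHeadModel()
print(list(model.sub_tasks.keys()))  # ['foo', 'bar']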
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/base_trainer.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union

import gin
import orbit
import tensorflow as tf

from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask


@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
  """Multitask base trainer."""

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None,
               train_datasets=None):
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer

    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    # Creates a shadow copy of the weights to store weights moving average.
    if isinstance(self._optimizer, optimization.ExponentialMovingAverage
                 ) and not self._optimizer.has_shadow_copy:
      self._optimizer.shadow_copy(multi_task_model)

    if hasattr(self.multi_task_model, "checkpoint_items"):
      checkpoint_items = self.multi_task_model.checkpoint_items
    else:
      checkpoint_items = {}

    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **checkpoint_items)

    if train_datasets is None:
      train_datasets = {}
      for name, task in self.multi_task.tasks.items():
        train_datasets[name] = orbit.utils.make_distributed_dataset(
            self.strategy, task.build_inputs, task.task_config.train_data)

    super().__init__(
        train_dataset=train_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())

  def train_loop_begin(self):
    """Clean up states that hold losses and metrics."""
    for _, train_loss_metric in self.training_losses.items():
      train_loss_metric.reset_states()

    for _, metrics in self.training_metrics.items():
      for metric in metrics:
        metric.reset_states()

  def train_loop_end(self):
    """Record loss and metric values per task."""
    result = {}
    for task_name, loss in self.training_losses.items():
      result[task_name] = {loss.name: loss.result()}
    for task_name, task_metrics in self.training_metrics.items():
      result[task_name].update(
          {metric.name: metric.result() for metric in task_metrics})
    # Note that, the learning rate schedule is managed by the keras optimizer
    # internally, which respects the number of backward passes as `iterations`.
    # The learning rate schedule does not follow the trainer logical global
    # step of multiple tasks.
    if callable(self.optimizer.learning_rate):
      result["learning_rate"] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      result["learning_rate"] = self.optimizer.learning_rate
    return result

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def training_losses(self):
    """Access training loss metric objects for all tasks."""
    if self._training_losses is None:
      # Builds the per-task metrics and losses.
      # This is the total summed training loss of tasks in the joint training.
      self._training_losses = dict(
          total_loss=tf.keras.metrics.Mean("training_loss", dtype=tf.float32))
      for name in self.multi_task.tasks:
        self._training_losses[name] = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
    return self._training_losses

  @property
  def training_metrics(self):
    """Access training metric objects for all tasks."""
    if self._training_metrics is None:
      # Builds the per-task metrics and losses.
      self._training_metrics = {}
      for name, task in self.multi_task.tasks.items():
        self._training_metrics[name] = task.build_metrics(training=True)
    return self._training_metrics

  @property
  def strategy(self):
    return self._strategy

  @property
  def multi_task(self):
    return self._multi_task

  @property
  def multi_task_model(self):
    return self._multi_task_model

  @property
  def optimizer(self):
    return self._optimizer

  @property
  def global_step(self):
    return self._global_step

  def train_step(self, iterator_map):
    """The default train step calling the multi-task train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """

    def step_fn(inputs):
      losses = self.multi_task.joint_train_step(
          inputs,
          multi_task_model=self.multi_task_model,
          optimizer=self.optimizer,
          task_metrics=self.training_metrics)
      for key, loss in losses.items():
        self.training_losses[key].update_state(loss)

    self.strategy.run(
        step_fn, args=(tf.nest.map_structure(next, iterator_map),))
    self.global_step.assign_add(1)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/base_trainer_test.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      task_weights = {"foo": 1.0, "bar": 1.0}
      test_multitask = multitask.MultiTask(
          tasks=tasks, task_weights=task_weights)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer)
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                results["foo"].keys())

  def test_trainer_with_configs(self):
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=0.5),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=0.5)))
    test_multitask = multitask.MultiTask.from_config(config)
    test_optimizer = tf.keras.optimizers.SGD(0.1)
    model = test_utils.MockMultiTaskModel()
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=test_multitask,
        multi_task_model=model,
        optimizer=test_optimizer)
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_multitask.task_weight("foo"), 0.5)
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)


if __name__ == "__main__":
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/configs.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple

import dataclasses

from official.core import config_definitions as cfg
from official.modeling import hyperparams


@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  # TODO(hongkuny): deprecate the task_name once we migrated client code.
  task_name: str = ""
  task_config: cfg.TaskConfig = None
  eval_steps: Optional[int] = None
  task_weight: Optional[float] = 1.0


@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  init_checkpoint: str = ""
  model: hyperparams.Config = None
  task_routines: Tuple[TaskRoutine, ...] = ()


@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  alpha: float = 1.0


@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  steps_per_epoch: int = 5
  total_steps: int = 20


@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  type: str = ""
  uniform: hyperparams.Config = hyperparams.Config()
  proportional: ProportionalSampleConfig = ProportionalSampleConfig()
  annealing: AnnealingSampleConfig = AnnealingSampleConfig()


@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  trainer_type: str = "interleaving"
  task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")


@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation."""
  task: MultiTaskConfig = MultiTaskConfig()
  trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()


@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Attributes:
    eval_tasks: individual evaluation tasks.
  """
  eval_tasks: Tuple[TaskRoutine, ...] = ()
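Usage note (not part of the diff): the config dataclasses above compose into a full experiment description. The sketch below assembles a hypothetical two-task experiment with proportional task sampling; the task names and the bare cfg.TaskConfig() placeholders are assumptions for illustration only.

from official.core import config_definitions as cfg
from official.modeling.multitask import configs

experiment = configs.MultiTaskExperimentConfig(
    task=configs.MultiTaskConfig(
        task_routines=(
            configs.TaskRoutine(
                task_name="foo", task_config=cfg.TaskConfig(), task_weight=2.0),
            configs.TaskRoutine(
                task_name="bar", task_config=cfg.TaskConfig(), task_weight=1.0),
        )),
    trainer=configs.MultiTaskTrainerConfig(
        trainer_type="interleaving",
        task_sampler=configs.TaskSamplingConfig(type="proportional")))

print(experiment.trainer.task_sampler.type)  # proportional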
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/evaluator.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Dict, List, Optional, Union

import gin
import orbit
import tensorflow as tf

from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model


@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Implements the common trainer shared for TensorFlow models."""

  def __init__(
      self,
      eval_tasks: List[base_task.Task],
      model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
      global_step: Optional[tf.Variable] = None,
      eval_steps: Optional[Dict[str, int]] = None,
      checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      eval_tasks: A list of tasks to evaluate.
      model: tf.keras.Model instance.
      global_step: the global step variable.
      eval_steps: a dictionary of steps to run eval keyed by task names.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._tasks = eval_tasks
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
    self._checkpoint_exporter = checkpoint_exporter
    if hasattr(self.model, "checkpoint_items"):
      checkpoint_items = self.model.checkpoint_items
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        model=self.model,
        global_step=self.global_step,
        **checkpoint_items)

    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task datasets.
    self.eval_datasets = {}
    self.eval_steps = eval_steps or {}
    for task in self.tasks:
      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.validation_data)

    # Builds per-task validation loops.
    def get_function(task_name, task):

      task_metrics = self.validation_metrics[task_name]
      task_loss = self.validation_losses[task_name]
      if isinstance(self.model, base_model.MultiTaskBaseModel):
        model = self.model.sub_tasks[task_name]
      else:
        model = self.model

      def step_fn(inputs):
        logs = task.validation_step(inputs, model=model, metrics=task_metrics)
        task_loss.update_state(logs[task.loss])
        return logs

      @tf.function
      def eval_step_fn(iterator):
        distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
        return tf.nest.map_structure(self.strategy.experimental_local_results,
                                     distributed_outputs)

      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
        task.name: get_function(task.name, task) for task in self.tasks
    }

  @property
  def strategy(self):
    return self._strategy

  @property
  def tasks(self):
    return self._tasks

  @property
  def model(self):
    return self._model

  @property
  def global_step(self):
    return self._global_step

  @property
  def validation_losses(self):
    """Accesses the validation loss metric object."""
    if self._validation_losses is None:
      # Builds the per-task metrics and losses.
      self._validation_losses = {}
      for task in self.tasks:
        self._validation_losses[task.name] = tf.keras.metrics.Mean(
            "validation_loss", dtype=tf.float32)
    return self._validation_losses

  @property
  def validation_metrics(self):
    """Accesses all validation metric objects."""
    if self._validation_metrics is None:
      # Builds the per-task metrics and losses.
      self._validation_metrics = {}
      for task in self.tasks:
        self._validation_metrics[task.name] = task.build_metrics(training=False)
    return self._validation_metrics

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  def evaluate(self, num_steps: tf.Tensor):
    """Performs evaluation for each `EvalTask`."""
    for metric in self.validation_losses.values():
      metric.reset_states()
    for metrics in self.validation_metrics.values():
      for metric in metrics:
        metric.reset_states()
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

    for task in self.tasks:
      outputs = None
      name = task.name
      eval_iter = eval_iters[name]
      task_eval_steps = self.eval_steps.get(name, None) or num_steps
      outputs = self.task_fns[name](
          eval_iter,
          task_eval_steps,
          state=outputs,
          reduce_fn=task.aggregate_logs)
      task_metrics = self.validation_metrics[name]
      task_loss = self.validation_losses[name]
      logs = {}
      for metric in task_metrics + [task_loss]:
        logs[metric.name] = metric.result()
      if outputs:
        metrics = task.reduce_aggregated_logs(
            outputs, global_step=self.global_step)
        logs.update(metrics)
      results[name] = logs

    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, results, self.global_step.numpy())
    return results
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/evaluator_test.py
new file (0 → 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class MockModel(tf.keras.Model):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    print(inputs, type(inputs))
    if "y" in inputs:
      self.add_loss(tf.zeros((1,), dtype=tf.float32))
    else:
      self.add_loss(tf.ones((1,), dtype=tf.float32))
    return self.dense(inputs["x"])


class MockTask(base_task.Task):
  """Mock task object for testing."""

  def build_metrics(self, training: bool = True):
    del training
    return [tf.keras.metrics.Accuracy(name="acc")]

  def build_inputs(self, params):

    def generate_data(_):
      x = tf.zeros(shape=(2,), dtype=tf.float32)
      label = tf.zeros([1], dtype=tf.int32)
      if self.name == "bar":
        return dict(x=x, y=x), label
      else:
        return dict(x=x), label

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(
        generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    logs = super().validation_step(inputs, model, metrics)
    logs["counter"] = tf.ones((1,), dtype=tf.float32)
    return logs

  def aggregate_logs(self, state, step_outputs):
    if state is None:
      state = {}
    for key, value in step_outputs.items():
      if key not in state:
        state[key] = []
      state[key].append(
          np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    for k, v in aggregated_logs.items():
      aggregated_logs[k] = np.sum(np.stack(v, axis=0))
    return aggregated_logs


class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator(self, distribution):
    with distribution.scope():
      tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo")
      ]
      model = MockModel()
      test_evaluator = evaluator.MultiTaskEvaluator(
          eval_tasks=tasks, model=model)
      results = test_evaluator.evaluate(
          tf.convert_to_tensor(1, dtype=tf.int32))
      self.assertContainsSubset(["validation_loss", "acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["validation_loss", "acc"],
                                results["foo"].keys())
      self.assertEqual(results["bar"]["validation_loss"], 0.0)
      self.assertEqual(results["foo"]["validation_loss"], 1.0)

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator_numpy_metrics(self, distribution):
    with distribution.scope():
      tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo")
      ]
      model = MockModel()
      test_evaluator = evaluator.MultiTaskEvaluator(
          eval_tasks=tasks, model=model)
      results = test_evaluator.evaluate(
          tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(results["bar"]["counter"],
                       5. * distribution.num_replicas_in_sync)
      self.assertEqual(results["foo"]["counter"],
                       5. * distribution.num_replicas_in_sync)


if __name__ == "__main__":
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/interleaving_trainer.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from typing import Union

import gin
import orbit
import tensorflow as tf

from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler as sampler


@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that interleaves task update."""

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    super().__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build per task train step.
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
    # on TensorBoard.
    self._task_step_counters = {
        name: orbit.utils.create_global_step()
        for name in self.multi_task.tasks
    }

  def task_step_counter(self, name):
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    # Sample one task to train according to a multinomial distribution
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
        self.global_step)
    # Prepend a [0.0] for indexing convenience.
    cumulative_sample_distribution = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
        axis=0)

    for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
      begin = cumulative_sample_distribution[idx]
      end = cumulative_sample_distribution[idx + 1]
      if rn >= begin and rn < end:
        self._strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)

  def train_loop_end(self):
    """Record loss and metric values per task."""
    result = super().train_loop_end()
    # Interleaving training does not have a good semantic for `total_loss`. In
    # fact, it is always zero. To avoid confusion, we filter the `total_loss`
    # from the result logs.
    if 'total_loss' in result:
      result.pop('total_loss')
    return result
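The train_step above samples exactly one task per step: it draws a stateless uniform random number keyed on the global step and picks the task whose interval of the cumulative distribution contains it. A minimal standalone sketch of that interval lookup (not part of this commit; the task names and probabilities are made up), runnable in plain TensorFlow eager mode:

import tensorflow as tf

task_names = ["foo", "bar"]
# Hypothetical cumulative distribution: foo with probability 0.75, bar with 0.25.
cumulative = tf.constant([0.75, 1.0], dtype=tf.float32)
# Prepend 0.0 for indexing convenience, as the trainer does.
cumulative = tf.concat([tf.constant([0.0], dtype=tf.float32), cumulative], axis=0)

step = 42  # stands in for the trainer's global_step
rn = tf.random.stateless_uniform(shape=[], seed=(0, step))

for idx, name in enumerate(task_names):
  begin, end = cumulative[idx], cumulative[idx + 1]
  if rn >= begin and rn < end:
    print("sampled task:", name)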
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/interleaving_trainer_test.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import configs
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
from official.modeling.multitask import test_utils


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      sampler = task_sampler.UniformTaskSampler(
          task_weights=test_multitask.task_weights)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                results["foo"].keys())
      self.assertNotIn("total_loss", results)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=3.0),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=1.0)))
    with distribution.scope():
      test_multitask = multitask.MultiTask.from_config(config)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      num_step = 1000
      sampler = task_sampler.AnnealingTaskSampler(
          task_weights=test_multitask.task_weights,
          steps_per_epoch=num_step / 5,
          total_steps=num_step)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(
          tf.convert_to_tensor(num_step, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_trainer.global_step.numpy(), num_step)
    bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
    foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
    self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)


if __name__ == "__main__":
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/multitask.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import abc
from typing import Dict, List, Optional, Text, Union

import tensorflow as tf

from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import configs

OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig


class MultiTask(tf.Module, metaclass=abc.ABCMeta):
  """A multi-task class to manage multiple tasks."""

  def __init__(self,
               tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
               task_weights: Optional[Dict[str, Union[float, int]]] = None,
               task_eval_steps: Optional[Dict[str, int]] = None,
               name: Optional[str] = None):
    """MultiTask initialization.

    Args:
      tasks: a list or a flat dict of Task.
      task_weights: a dict of (task, task weight), task weight can be applied
        directly during loss summation in a joint backward step, or it can be
        used to sample task among interleaved backward step.
      task_eval_steps: a dict of (task, eval steps).
      name: the instance name of a MultiTask object.
    """
    super().__init__(name=name)
    if isinstance(tasks, list):
      self._tasks = {}
      for task in tasks:
        if task.name in self._tasks:
          raise ValueError("Duplicated tasks found, task.name is %s" %
                           task.name)
        self._tasks[task.name] = task
    elif isinstance(tasks, dict):
      self._tasks = tasks
    else:
      raise ValueError("The tasks argument has an invalid type: %s" %
                       type(tasks))
    self.task_eval_steps = task_eval_steps or {}
    self._task_weights = task_weights or {}
    self._task_weights = dict([
        (name, self._task_weights.get(name, 1.0)) for name in self.tasks
    ])

  @classmethod
  def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
    tasks = {}
    task_eval_steps = {}
    task_weights = {}
    for task_routine in config.task_routines:
      task_name = task_routine.task_name or task_routine.task_config.name
      tasks[task_name] = task_factory.get_task(
          task_routine.task_config, logging_dir=logging_dir, name=task_name)
      task_eval_steps[task_name] = task_routine.eval_steps
      task_weights[task_name] = task_routine.task_weight
    return cls(
        tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)

  @property
  def tasks(self):
    return self._tasks

  def task_weight(self, task_name):
    return self._task_weights[task_name]

  @property
  def task_weights(self):
    return self._task_weights

  @classmethod
  def create_optimizer(cls,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    return base_task.Task.create_optimizer(
        optimizer_config=optimizer_config, runtime_config=runtime_config)

  def joint_train_step(self, task_inputs,
                       multi_task_model: base_model.MultiTaskBaseModel,
                       optimizer: tf.keras.optimizers.Optimizer, task_metrics,
                       **kwargs):
    """The joint train step.

    Args:
      task_inputs: a dictionary of task names and per-task features.
      multi_task_model: a MultiTaskBaseModel instance.
      optimizer: a tf.optimizers.Optimizer.
      task_metrics: a dictionary of task names and per-task metrics.
      **kwargs: other arguments to pass through.

    Returns:
      A dictionary of losses, including per-task losses and their weighted sum.
    """
    losses = {}
    with tf.GradientTape() as tape:
      total_loss = 0.0
      for name, model in multi_task_model.sub_tasks.items():
        inputs = task_inputs[name]
        if isinstance(inputs, tuple) and len(inputs) == 2:
          features, labels = inputs
        elif isinstance(inputs, dict):
          features, labels = inputs, inputs
        else:
          raise ValueError("The iterator output is neither a tuple nor a "
                           "dictionary. It is not implemented to support "
                           "such outputs.")
        outputs = model(features, training=True)
        task_loss = self.tasks[name].build_losses(labels, outputs)
        task_weight = self.task_weight(name)
        total_loss += task_weight * task_loss
        losses[name] = task_loss
        self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
                                         **kwargs)

      # Scales loss as the default gradients allreduce performs sum inside
      # the optimizer.
      scaled_loss = total_loss / tf.distribute.get_strategy(
      ).num_replicas_in_sync
    tvars = multi_task_model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    losses["total_loss"] = total_loss
    return losses
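In joint_train_step, each task's loss is scaled by its configured task weight and the weighted losses are summed into a single total_loss for one shared backward pass. A tiny worked example of that weighting (not part of this commit; the weights and loss values are made up):

task_weights = {"foo": 3.0, "bar": 1.0}       # hypothetical task weights
per_task_losses = {"foo": 0.20, "bar": 0.60}  # hypothetical per-task losses

total_loss = 0.0
for name, task_loss in per_task_losses.items():
  total_loss += task_weights[name] * task_loss

print(total_loss)  # 3.0 * 0.20 + 1.0 * 0.60 = 1.2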
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/task_sampler.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import abc
from typing import Union, Dict, Text

import tensorflow as tf

from official.modeling.multitask import configs


class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    self._task_weights = task_weights

  @property
  def task_weights(self):
    return self._task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which to be sampled against.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float tensor with shape (#(task), 1) that represents the cumulative
        sampling distribution.
    """
    pass


class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    super(UniformTaskSampler, self).__init__(task_weights=task_weights)
    self._uniform_cumulative = tf.math.cumsum(
        tf.constant(
            [1.0 / len(self._task_weights)] * len(self._task_weights),
            dtype=tf.float32))

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    del global_step
    return self._uniform_cumulative


class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    task_weight_dict_ordered_list = tf.constant(
        [weight for _, weight in self._task_weights.items()], dtype=tf.float32)
    task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
    task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    self._porportional_cumulative = tf.math.cumsum(task_distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    del global_step
    return self._porportional_cumulative


class AnnealingTaskSampler(TaskSampler):
  """Sample tasks according to task weights as well as training progress.

  See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    cur_epoch = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    task_weight_dict_ordered_list = [
        weight for _, weight in self._task_weights.items()
    ]
    task_sizes = tf.math.pow(
        tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
        tf.cast(alpha, dtype=tf.float32))
    dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    return tf.math.cumsum(dynamic_task_distribution)


def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Utils to create task sampler with configuration and task weights."""
  oneof_config = config.get()
  if config.type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  elif config.type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  elif config.type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  else:
    raise RuntimeError('Task sampler type not supported')
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/task_sampler_test.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import tensorflow as tf

from official.modeling.multitask import configs
from official.modeling.multitask import task_sampler as sampler


class TaskSamplerTest(tf.test.TestCase):

  def setUp(self):
    super(TaskSamplerTest, self).setUp()
    self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}

  def test_uniform_sample_distribution(self):
    uniform_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    for step in range(5):
      cumulative_distribution = uniform_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0],
                          cumulative_distribution.numpy())

  def test_proportional_sample_distribution(self):
    prop_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
    for step in range(5):
      cumulative_distribution = prop_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0],
                          cumulative_distribution.numpy())

  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    annel_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=step_per_epoch,
                total_steps=step_per_epoch * num_epoch)), self._task_weights)

    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      for _ in range(step_per_epoch):
        cumulative_distribution = annel_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative_distribution.numpy())


if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/test_utils.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utils for mock models and tasks."""
from typing import Dict, Text

import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.multitask import base_model


class MockFooModel(tf.keras.Model):
  """A mock model can consume 'foo' and 'bar' inputs."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._share_layer = shared_layer
    self._foo_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    self.add_loss(tf.zeros((1,), dtype=tf.float32))
    if "foo" in inputs:
      input_tensor = inputs["foo"]
    else:
      input_tensor = inputs["bar"]
    return self._foo_specific_layer(self._share_layer(input_tensor))


class MockBarModel(tf.keras.Model):

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._share_layer = shared_layer
    self._bar_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    self.add_loss(tf.zeros((2,), dtype=tf.float32))
    return self._bar_specific_layer(self._share_layer(inputs["bar"]))


class MockMultiTaskModel(base_model.MultiTaskBaseModel):

  def __init__(self, *args, **kwargs):
    self._shared_dense = tf.keras.layers.Dense(1)
    super().__init__(*args, **kwargs)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    return {
        "foo": MockFooModel(self._shared_dense),
        "bar": MockBarModel(self._shared_dense)
    }


def mock_data(feature_name):
  """Mock dataset function."""

  def _generate_data(_):
    x = tf.zeros(shape=(2,), dtype=tf.float32)
    label = tf.zeros([1], dtype=tf.int32)
    return {feature_name: x}, label

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      _generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)


class FooConfig(cfg.TaskConfig):
  pass


class BarConfig(cfg.TaskConfig):
  pass


@task_factory.register_task_cls(FooConfig)
class MockFooTask(base_task.Task):
  """Mock foo task object for testing."""

  def build_metrics(self, training: bool = True):
    del training
    return [tf.keras.metrics.Accuracy(name="foo_acc")]

  def build_inputs(self, params):
    return mock_data("foo")

  def build_model(self) -> tf.keras.Model:
    return MockFooModel(shared_layer=tf.keras.layers.Dense(1))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)


@task_factory.register_task_cls(BarConfig)
class MockBarTask(base_task.Task):
  """Mock bar task object for testing."""

  def build_metrics(self, training: bool = True):
    del training
    return [tf.keras.metrics.Accuracy(name="bar_acc")]

  def build_inputs(self, params):
    return mock_data("bar")

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/train_lib.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, List, Optional, Tuple

from absl import logging
import orbit
import tensorflow as tf

from official.core import base_task
from official.core import base_trainer as core_lib
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler

TRAINERS = {
    'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
    'joint': base_trainer.MultiTaskBaseTrainer
}


def run_experiment(
    *,
    distribution_strategy: tf.distribute.Strategy,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel,
    mode: str,
    params: configs.MultiTaskExperimentConfig,
    model_dir: str,
    trainer: base_trainer.MultiTaskBaseTrainer = None
) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A MultiTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    trainer: (optional) A multi-task trainer to use. If none is provided, a
      default one will be created based on `params`.

  Returns:
    model: `base_model.MultiTaskBaseModel` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    if trainer is None:
      trainer = TRAINERS[params.trainer.trainer_type](
          **kwargs) if is_training else None
    if is_eval:
      eval_steps = task.task_eval_steps
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=task.tasks.values(),
          model=model,
          eval_steps=eval_steps,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  return model


def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: List[base_task.Task],
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    train_task: A base_task.Task instance.
    eval_tasks: A list of evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    trainer: the core_lib.Trainer instance. It should be created within the
      strategy.scope(). If not provided, an instance will be created by default
      if `mode` contains 'train'.

  Returns:
    model: `tf.keras.Model` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    if is_training:
      trainer = trainer or core_lib.Trainer(
          config=params,
          task=train_task,
          model=train_task.build_model(),
          optimizer=train_task.create_optimizer(
              params.trainer.optimizer_config, params.runtime),
          train=True,
          evaluate=False)
    else:
      trainer = None
    model = trainer.model if trainer else train_task.build_model()

    if is_eval:
      eval_steps = dict([(task_routine.task_config.name,
                          task_routine.eval_steps)
                         for task_routine in params.eval_tasks])
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None,
          eval_steps=eval_steps,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(model_dir, 'validation')
      if save_summary else None,
      summary_interval=params.trainer.summary_interval
      if save_summary else None)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if run_post_eval:
    return model, evaluator.evaluate(
        tf.convert_to_tensor(
            params.trainer.validation_steps))  # pytype: disable=bad-return-type  # typed-keras
  else:
    return model, {}  # pytype: disable=bad-return-type  # typed-keras
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/multitask/train_lib_test.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.train_lib."""
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import task_factory
from official.modeling.hyperparams import params_dict
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
from official.modeling.multitask import train_lib


class TrainLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
                    task_name='foo', task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=test_multitask,
        model=model,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=(configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
                    configs.TaskRoutine(
                        task_name='bar',
                        task_config=test_utils.BarConfig(),
                        eval_steps=3)))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = [
          task_factory.get_task(config.task_config, name=config.task_name)
          for config in experiment_config.eval_tasks
      ]
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=distribution_strategy,
        train_task=train_task,
        eval_tasks=eval_tasks,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)


if __name__ == '__main__':
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/optimization/__init__.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization package definition."""
# pylint: disable=wildcard-import
from official.modeling.optimization.configs.learning_rate_config import *
from official.modeling.optimization.configs.optimization_config import *
from official.modeling.optimization.configs.optimizer_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.modeling.optimization.lr_schedule import *
from official.modeling.optimization.optimizer_factory import OptimizerFactory
from official.modeling.optimization.optimizer_factory import register_optimizer_cls
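A common way to consume these re-exports is to build an optimizer from an OptimizationConfig through OptimizerFactory. A minimal sketch (not part of this commit; the config values are illustrative and assume the SGD and constant learning-rate configs re-exported above):

from official.modeling import optimization

# Illustrative config: SGD with a constant learning rate.
opt_config = optimization.OptimizationConfig({
    'optimizer': {'type': 'sgd'},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
})
opt_factory = optimization.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)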
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/optimization/adafactor_optimizer.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adafactor optimizer.
A new optimizer that will be open sourced soon.
"""
# pylint: disable=invalid-name, represents an unimplemented class definition.
Adafactor = "Unimplemented"