Commit a32ffa95 authored by qianyj's avatar qianyj
Browse files

update TensorFlow2x test method

parent e286da17
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import tensorflow as tf
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import oneof
@dataclasses.dataclass
class ResNet(base_config.Config):
  """Toy ResNet backbone config used by the OneOf tests below."""
  # Depth of the ResNet variant (e.g. 50 for ResNet-50).
  model_depth: int = 50
@dataclasses.dataclass
class Backbone(oneof.OneOfConfig):
  """OneOf config selecting the backbone; `type` names the active field."""
  type: str = 'resnet'
  resnet: ResNet = ResNet()
  not_resnet: int = 2
@dataclasses.dataclass
class OutputLayer(oneof.OneOfConfig):
  """OneOf config selecting the output layer; `type` names the active field."""
  type: str = 'single'
  single: int = 1
  multi_head: int = 2
@dataclasses.dataclass
class Network(base_config.Config):
  """Composite config combining a backbone choice and an output-layer choice."""
  backbone: Backbone = Backbone()
  output_layer: OutputLayer = OutputLayer()
class OneOfTest(tf.test.TestCase):
  """Tests for oneof.OneOfConfig round-tripping and selection."""

  def test_to_dict(self):
    # A nested dict used to build the config should serialize back to an
    # equal dict via as_dict().
    network_params = {
        'backbone': {
            'type': 'resnet',
            'resnet': {
                'model_depth': 50
            }
        },
        'output_layer': {
            'type': 'single',
            'single': 1000
        }
    }
    network_config = Network(network_params)
    self.assertEqual(network_config.as_dict(), network_params)

  def test_get_oneof(self):
    backbone = Backbone()
    # get() should return the sub-config named by `type` ('resnet' here).
    self.assertIsInstance(backbone.get(), ResNet)
    self.assertEqual(backbone.get().as_dict(), {'model_depth': 50})
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A parameter dictionary class which supports the nest structure."""
import collections
import copy
import re
import six
import tensorflow as tf
import yaml
# regex pattern that matches on key-value pairs in a comma-separated
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(
r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
|
\"(.*?)\" # double quote
|
[^,\[]* # single value
|
\[[^\]]*\])) # list of values
($|,\s*)""", re.VERBOSE)
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# YAML loader with an implicit resolver so float literals in decimal and
# exponential notation are parsed as floats rather than strings. The regular
# expression accepts the following cases:
#   1. Decimal number with an optional exponential term.
#   2. Integer number with an exponential term (e.g. "1e+3").
#   3. Bare decimal fraction with an optional exponential term (e.g. ".5").
#   4. Sexagesimal-style decimal number (YAML 1.1 legacy form).
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    # NOTE: this is a raw string, so the escape for a literal dot is a single
    # backslash (`\.`). The previous `\\.` matched a literal backslash
    # followed by any character, so the decimal branches never matched.
    re.compile(
        r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |
        [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |
        \.[0-9_]+(?:[eE][-+][0-9]+)?
        |
        [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*)$''', re.X),
    list('-+0123456789.'))
class ParamsDict(object):
  """A hyperparameter container class.

  Parameters are stored as attributes; nested dicts become nested ParamsDict
  instances. An optional list of restriction strings can be supplied and
  later enforced with `validate()`.
  """

  # Internal bookkeeping attributes: exempt from the existence/lock checks
  # in `__setattr__` and excluded from `as_dict()`.
  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=', '<', '<=', '>', '>='}.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    self.override(default_params, is_strict=False)

  def _set(self, k, v):
    # Nested dicts become nested ParamsDict objects; every other value is
    # deep-copied so callers cannot mutate our state through an alias.
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__:
        # Fixed message: previously contained a stray '%' and advised
        # `is_strict` = True, but extending keys requires False.
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)

  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowed. '.format(k))
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is
        currently defined in the ParamsDict. In this case, the ParamsDict will
        be extended to include the new keys.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)  # pylint: disable=protected-access

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`."""
    for k, v in override_dict.items():
      if k in ParamsDict.RESERVED_ATTR:
        # Fixed: `.format(k)` was missing, so the raised message showed a
        # literal `%{}` placeholder instead of the key name.
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        else:
          self._set(k, v)
      else:
        if isinstance(v, dict):
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, ParamsDict):
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self.__dict__[k] = copy.deepcopy(v)

  def lock(self):
    """Makes the ParamsDict immutable."""
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned.
    """
    params_dict = {}
    for k, v in self.__dict__.items():
      if k not in ParamsDict.RESERVED_ATTR:
        if isinstance(v, ParamsDict):
          params_dict[k] = v.as_dict()
        else:
          params_dict[k] = copy.deepcopy(v)
    return params_dict

  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list
    of restrictions. A restriction is defined as a string which specifies a
    binary operation. The supported binary operations are {'==', '!=', '<',
    '<=', '>', '>='}. Note that the meaning of these operators are consistent
    with the underlying Python implementation. Users should make sure the
    defined restrictions on their type make sense.

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
    - a.a1 = 1 == b.ccc.a1 = 1
    - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_kv(dotted_string, params_dict):
      """Returns the terminal key and value addressed by `dotted_string`."""
      if _CONST_VALUE_RE.match(dotted_string) is not None:
        # The token is a constant (number or 'None'), not a parameter path.
        const_str = dotted_string
        if const_str == 'None':
          constant = None
        else:
          constant = float(const_str)
        return None, constant
      else:
        tokenized_params = dotted_string.split('.')
        v = params_dict
        for t in tokenized_params:
          v = v[t]
        return tokenized_params[-1], v

    def _get_kvs(tokens, params_dict):
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      stripped_tokens = [t.strip() for t in tokens]
      left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
      right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
      return left_k, left_v, right_k, right_v

    params_dict = self.as_dict()
    for restriction in self._restrictions:
      # NOTE: two-character operators ('<=', '>=') must be tested BEFORE
      # their one-character prefixes ('<', '>'); otherwise 'a <= b' would be
      # split on '<' into ['a ', '= b'] and mis-parsed. The previous ordering
      # had this bug.
      if '==' in restriction:
        tokens = restriction.split('==')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v != right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '!=' in restriction:
        tokens = restriction.split('!=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v == right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<=' in restriction:
        tokens = restriction.split('<=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v > right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<' in restriction:
        tokens = restriction.split('<')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v >= right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>=' in restriction:
        tokens = restriction.split('>=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v < right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>' in restriction:
        tokens = restriction.split('>')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v <= right_v:
          raise KeyError(
              'Found inconsistency between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      else:
        raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path: str):
  """Loads the YAML file at `file_path` and wraps it in a ParamsDict."""
  with tf.io.gfile.GFile(file_path, 'r') as yaml_file:
    loaded = yaml.load(yaml_file, Loader=LOADER)
  return ParamsDict(loaded)
def save_params_dict_to_yaml(params, file_path):
  """Serializes `params` as YAML and writes it to `file_path`."""
  with tf.io.gfile.GFile(file_path, 'w') as yaml_file:

    def _flow_style_list_rep(dumper, data):
      # 'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
      # Flow style renders lists inline, e.g. `[1, 2, 3]`.
      return dumper.represent_sequence(
          'tag:yaml.org,2002:seq', data, flow_style=True)

    yaml.add_representer(list, _flow_style_list_rep)
    yaml.dump(params.as_dict(), yaml_file, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
  """Converts a nested (using '.') comma-separated k=v string to a JSON string.

  Converts a comma-separated string of key/value pairs that supports
  nesting of keys to a JSON string. Nesting is implemented using
  '.' between levels for a given key.

  Spacing between commas and = is supported (e.g. there is no difference
  between "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no
  spaces before keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not
  supported).

  Note that this will only support values supported by CSV, meaning
  values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
  supported. Strings are supported as well, e.g. "a='hello'".

  An example conversion would be:
  "a=1, b=2, c.a=2, c.b=3, d.a.a=5"
  to
  "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string.

  Raises:
    ValueError: If csv_str is not in a comma separated string or
      if the string is formatted incorrectly.
  """
  if not csv_str:
    return ''

  formatted_entries = []
  nested_map = collections.defaultdict(list)
  pos = 0
  while pos < len(csv_str):
    m = _PARAM_RE.match(csv_str, pos)
    if not m:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[pos:])
    pos = m.end()
    # Parse the values.
    m_dict = m.groupdict()
    name = m_dict['name']
    v = m_dict['val']

    # If a GCS path (e.g. gs://...) is provided, wrap it in quotes as
    # yaml.load would otherwise throw an exception.
    # Fixed: the previous check `re.match(r'(?=[^\"\'])(?=[gs://])', v)` used
    # `[gs://]` as a character class, so ANY unquoted value starting with
    # 'g', 's', ':' or '/' was quoted, not just GCS paths.
    if v.startswith('gs://'):
      v = '\'{}\''.format(v)

    name_nested = name.split('.')
    if len(name_nested) > 1:
      # Defer nested keys: group the tail under the first path component and
      # recurse once per group below.
      grouping = name_nested[0]
      value = '.'.join(name_nested[1:]) + '=' + v
      nested_map[grouping].append(value)
    else:
      formatted_entries.append('%s : %s' % (name, v))

  for grouping, value in nested_map.items():
    value = ','.join(value)
    value = nested_csv_str_to_json_str(value)
    formatted_entries.append('%s : %s' % (grouping, value))
  return '{' + ', '.join(formatted_entries) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The logic of the function is outlined below:
  1. Test that the input is a dict. If not, proceed to 2.
  2. Tests that the input is a string. If not, raise unknown ValueError.
    2.1. Test if the string is in a CSV format. If so, parse.
         If not, proceed to 2.2.
    2.2. Try loading the string as a YAML/JSON. If successful, parse to
         dict and use it to override. If not, proceed to 2.3.
    2.3. Try using the string as a file path and load the YAML file.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object.

  Raises:
    ValueError: if failed to override the parameters.
  """
  if not dict_or_string_or_yaml_file:
    return params
  if isinstance(dict_or_string_or_yaml_file, dict):
    params.override(dict_or_string_or_yaml_file, is_strict)
  elif isinstance(dict_or_string_or_yaml_file, six.string_types):
    try:
      # First, try to interpret the string as nested CSV ("a=1,b.c=2") and
      # convert it to an inline JSON/YAML string; fall through on failure.
      dict_or_string_or_yaml_file = (
          nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
    except ValueError:
      pass
    params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
    if isinstance(params_dict, dict):
      params.override(params_dict, is_strict)
    else:
      # The string did not parse to a mapping; assume it is a path to a
      # YAML file and load that instead.
      with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
        params.override(yaml.load(f, Loader=yaml.FullLoader), is_strict)
  else:
    raise ValueError('Unknown input type to parse.')
  return params
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for params_dict.py."""
import os
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
class ParamsDictTest(tf.test.TestCase):
  """Unit tests for params_dict.ParamsDict core behavior."""

  def test_init_from_an_empty_dict(self):
    params = params_dict.ParamsDict()
    with self.assertRaises(AttributeError):
      _ = params.a
    with self.assertRaises(KeyError):
      params.a = 'aa'

  def test_init_from_a_dict(self):
    params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

  def test_init_from_a_param_dict(self):
    params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
    params = params_dict.ParamsDict(params_init)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

  def test_lock(self):
    # After lock(), attribute set, override and delete must all raise.
    params = params_dict.ParamsDict({'a': 1, 'b': 2, 'c': 3})
    params.lock()
    with self.assertRaises(ValueError):
      params.a = 10
    with self.assertRaises(ValueError):
      params.override({'b': 20})
    with self.assertRaises(ValueError):
      del params.c

  def test_setattr(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    params.c = 'ccc'
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, 'ccc')

  def test_getattr(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, None)

  def test_delattr(self):
    params = params_dict.ParamsDict()
    params.override({
        'a': 'aa',
        'b': 2,
        'c': None,
        'd': {
            'd1': 1,
            'd2': 10
        }
    },
                    is_strict=False)
    del params.c
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    with self.assertRaises(AttributeError):
      _ = params.c
    del params.d
    with self.assertRaises(AttributeError):
      _ = params.d.d1

  def test_contains(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertIn('a', params)
    self.assertNotIn('b', params)

  def test_get(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertEqual(params.get('a'), 'aa')
    self.assertEqual(params.get('b', 2), 2)
    self.assertEqual(params.get('b'), None)

  def test_override_is_strict_true(self):
    # Strict override may only touch existing keys, including nested ones.
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 'cc',
            'c2': 20
        }
    })
    params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c1, 'ccc')
    with self.assertRaises(KeyError):
      params.override({'d': 'ddd'}, is_strict=True)
    with self.assertRaises(KeyError):
      params.override({'c': {'c3': 30}}, is_strict=True)

  def test_override_is_strict_false(self):
    # Non-strict override silently extends the ParamsDict with new keys.
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c3, 3000)
    params.override({'d': 'ddd'}, is_strict=False)
    self.assertEqual(params.d, 'ddd')
    params.override({'c': {'c4': 4444}}, is_strict=False)
    self.assertEqual(params.c.c4, 4444)

  def test_as_dict(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    params_d = params.as_dict()
    self.assertEqual(params_d['a'], 'aa')
    self.assertEqual(params_d['b'], 2)
    self.assertEqual(params_d['c']['c1'], 10)
    self.assertEqual(params_d['c']['c2'], 20)

  def test_validate(self):
    # Raise error due to the unknown parameter.
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
      params.validate()

    # OK to check equality of two nested dicts.
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'a': 10
        },
        'c': {
            'a': 10
        }
    }, ['b == c'])

    # Raise error due to inconsistency.
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
      params.validate()

    # Valid rule.
    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])

    # Overriding violates the existing rule, raise error upon validate.
    params.override({'a': 11})
    with self.assertRaises(KeyError):
      params.validate()

    # Valid restrictions with constant.
    params = params_dict.ParamsDict({
        'a': None,
        'c': {
            'a': 1
        }
    }, ['a == None', 'c.a == 1'])
    params.validate()
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({
          'a': 4,
          'c': {
              'a': 1
          }
      }, ['a == None', 'c.a == 1'])
      params.validate()
class ParamsDictIOTest(tf.test.TestCase):
  """Tests for the YAML read/write and override helpers of params_dict."""

  def write_temp_file(self, filename, text):
    """Writes `text` to a file under the test temp dir; returns its path."""
    temp_file = os.path.join(self.get_temp_dir(), filename)
    with tf.io.gfile.GFile(temp_file, 'w') as writer:
      writer.write(text)
    return temp_file

  def test_save_params_dict_to_yaml(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
    params_dict.save_params_dict_to_yaml(params, output_yaml_file)

    with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
      # Fixed: yaml.load without an explicit Loader is deprecated and raises
      # TypeError on PyYAML >= 6.
      params_d = yaml.load(f, Loader=yaml.SafeLoader)
      self.assertEqual(params.a, params_d['a'])
      self.assertEqual(params.b, params_d['b'])
      self.assertEqual(params.c.c1, params_d['c']['c1'])
      self.assertEqual(params.c.c2, params_d['c']['c2'])

  def test_read_yaml_to_params_dict(self):
    input_yaml_file = self.write_temp_file(
        'params.yaml', r"""
        a: 'aa'
        b: 2
        c:
          c1: 10
          c2: 20
    """)
    params = params_dict.read_yaml_to_params_dict(input_yaml_file)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c.c1, 10)
    self.assertEqual(params.c.c2, 20)

  def test_override_params_dict_using_dict(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_dict = {'b': 5.2, 'c': [30, 40]}
    params = params_dict.override_params_dict(
        params, override_dict, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_yaml_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_string = "'b': 5.2\n'c': [30, 40]"
    params = params_dict.override_params_dict(
        params, override_yaml_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_json_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'b1': 2,
            'b2': [2, 3],
        },
        'd': {
            'd1': {
                'd2': 'hello'
            }
        },
        'e': False
    })
    override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
    params = params_dict.override_params_dict(
        params, override_json_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(2, params.b.b1)
    self.assertEqual([3, 4], params.b.b2)
    self.assertEqual('hi', params.d.d1.d2)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_csv_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'b1': 2,
            'b2': [2, 3],
        },
        'd': {
            'd1': {
                'd2': 'hello'
            }
        },
        'e': False
    })
    override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(2, params.b.b1)
    self.assertEqual([3, 4], params.b.b2)
    self.assertEqual('hi, world', params.d.d1.d2)
    self.assertEqual('gs://test', params.e)
    # Test different float formats.
    override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)
    self.assertEqual(-1e-3, params.b.b2)
    self.assertEqual(0.001, params.d.d1.d2)
    self.assertEqual(1e3, params.e)
    self.assertEqual(-1.5e-3, params.a)

  def test_override_params_dict_using_yaml_file(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_file = self.write_temp_file(
        'params.yaml', r"""
        b: 5.2
        c: [30, 40]
    """)
    params = params_dict.override_params_dict(
        params, override_yaml_file, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)
class IOTest(tf.test.TestCase):
  """Tests for nested_csv_str_to_json_str conversion and parsing."""

  def test_basic_csv_str_to_json_str(self):
    csv_str = 'a=1,b=2,c=3'
    json_str = '{a : 1, b : 2, c : 3}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, json_str)

  def test_basic_csv_str_load(self):
    csv_str = 'a=1,b=2,c=3'
    expected_output = {'a': 1, 'b': 2, 'c': 3}
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    # Fixed: yaml.load without an explicit Loader is deprecated and raises
    # TypeError on PyYAML >= 6.
    converted_dict = yaml.load(converted_csv_str, Loader=yaml.SafeLoader)
    self.assertDictEqual(converted_dict, expected_output)

  def test_basic_nested_csv_str_to_json_str(self):
    csv_str = 'a=1,b.b1=2'
    json_str = '{a : 1, b : {b1 : 2}}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, json_str)

  def test_basic_nested_csv_str_load(self):
    csv_str = 'a=1,b.b1=2,c.c1=3'
    expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str, Loader=yaml.SafeLoader)
    self.assertDictEqual(converted_dict, expected_output)

  def test_complex_nested_csv_str_to_json_str(self):
    csv_str = 'a.aa.aaa.aaaaa.a=1'
    json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, json_str)

  def test_complex_nested_csv_str_load(self):
    csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
    expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str, Loader=yaml.SafeLoader)
    self.assertDictEqual(converted_dict, expected_output)

  def test_csv_str_load_supported_datatypes(self):
    csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    converted_dict = yaml.load(converted_csv_str, Loader=yaml.SafeLoader)
    self.assertEqual(converted_dict['a'], 1)
    self.assertEqual(converted_dict['b'], 2.)
    self.assertEqual(converted_dict['c'], [1, 2, 3])
    self.assertEqual(converted_dict['d'], 'hello, there')
    self.assertEqual(converted_dict['e'], 'Hi.')

  def test_csv_str_load_unsupported_datatypes(self):
    # Nested lists are not representable in the CSV format.
    csv_str = 'a=[[1,2,3],[4,5,6]]'
    self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
                      csv_str)

  def test_csv_str_to_json_str_spacing(self):
    csv_str1 = 'a=1,b=2,c=3'
    csv_str2 = 'a = 1, b = 2, c = 3'
    json_str = '{a : 1, b : 2, c : 3}'
    converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
    converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
    self.assertEqual(converted_csv_str1, converted_csv_str2)
    self.assertEqual(converted_csv_str1, json_str)
    self.assertEqual(converted_csv_str2, json_str)

  def test_gcs_added_quotes(self):
    csv_str = 'a=gs://abc, b=gs://def'
    expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
    converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
    self.assertEqual(converted_csv_str, expected_output)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstraction of multi-task model."""
from typing import Text, Dict
import tensorflow as tf
class MultiTaskBaseModel(tf.Module):
  """Base class that holds multi-task model computation."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Build all per-task models once at construction time; subclasses must
    # implement _instantiate_sub_tasks().
    self._sub_tasks = self._instantiate_sub_tasks()

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    """Abstract function that sets up the computation for each sub-task.

    Returns:
      A map from task name (as string) to a tf.keras.Model object that
      represents the sub-task in the multi-task pool.
    """
    raise NotImplementedError(
        "_instantiate_sub_task_models() is not implemented.")

  @property
  def sub_tasks(self):
    """Fetch a map of task name (string) to task model (tf.keras.Model)."""
    return self._sub_tasks

  def initialize(self):
    """Optional function that loads a pre-train checkpoint."""
    return
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
"""Multitask base trainer."""
  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None,
               train_datasets=None):
    """Initializes the multitask trainer.

    Args:
      multi_task: a MultiTask instance holding the per-task definitions.
      multi_task_model: the joint model, either a keras Model or a
        MultiTaskBaseModel.
      optimizer: the optimizer shared across all tasks.
      trainer_options: optional orbit.StandardTrainerOptions; defaults are
        used when None.
      train_datasets: optional map of task name to dataset; when None, one
        distributed dataset is built per task from its task_config.
    """
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer

    # Loss/metric containers are built lazily by the corresponding
    # properties on first access.
    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    # Creates a shadow copy of the weights to store weights moving average.
    if isinstance(self._optimizer, optimization.ExponentialMovingAverage
                 ) and not self._optimizer.has_shadow_copy:
      self._optimizer.shadow_copy(multi_task_model)

    # Include any model-provided checkpoint items (e.g. sub-model weights)
    # alongside model/optimizer/global_step in the checkpoint.
    if hasattr(self.multi_task_model, "checkpoint_items"):
      checkpoint_items = self.multi_task_model.checkpoint_items
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **checkpoint_items)

    if train_datasets is None:
      train_datasets = {}
      for name, task in self.multi_task.tasks.items():
        train_datasets[name] = orbit.utils.make_distributed_dataset(
            self.strategy, task.build_inputs, task.task_config.train_data)

    super().__init__(
        train_dataset=train_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())
def train_loop_begin(self):
"""Clean up states that hold losses and metrics."""
for _, train_loss_metric in self.training_losses.items():
train_loss_metric.reset_states()
for _, metrics in self.training_metrics.items():
for metric in metrics:
metric.reset_states()
def train_loop_end(self):
"""Record loss and metric values per task."""
result = {}
for task_name, loss in self.training_losses.items():
result[task_name] = {loss.name: loss.result()}
for task_name, task_metrics in self.training_metrics.items():
result[task_name].update(
{metric.name: metric.result() for metric in task_metrics})
# Note that, the learning rate schedule is managed by the keras optimizer
# internally, which respects the number of backward pass as `iterations`.
# The learning rate schedule does not follow the trainer logical global
# step of multiple tasks.
if callable(self.optimizer.learning_rate):
result["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
result["learning_rate"] = self.optimizer.learning_rate
return result
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
@property
def training_losses(self):
"""Access training loss metric objects for all tasks."""
if self._training_losses is None:
# Builds the per-task metrics and losses.
# This the total summed training loss of tasks in the joint training.
self._training_losses = dict(
total_loss=tf.keras.metrics.Mean("training_loss", dtype=tf.float32))
for name in self.multi_task.tasks:
self._training_losses[name] = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32)
return self._training_losses
@property
def training_metrics(self):
"""Access training metric metric objects for all tasks."""
if self._training_metrics is None:
# Builds the per-task metrics and losses.
self._training_metrics = {}
for name, task in self.multi_task.tasks.items():
self._training_metrics[name] = task.build_metrics(training=True)
return self._training_metrics
@property
def strategy(self):
return self._strategy
@property
def multi_task(self):
return self._multi_task
@property
def multi_task_model(self):
return self._multi_task_model
@property
def optimizer(self):
return self._optimizer
@property
def global_step(self):
return self._global_step
def train_step(self, iterator_map):
"""The default train step calling the multi-task train step.
Args:
iterator_map: a dictionary of task names and per-task dataset iterators.
"""
def step_fn(inputs):
losses = self.multi_task.joint_train_step(
inputs,
multi_task_model=self.multi_task_model,
optimizer=self.optimizer,
task_metrics=self.training_metrics)
for key, loss in losses.items():
self.training_losses[key].update_state(loss)
self.strategy.run(
step_fn, args=(tf.nest.map_structure(next, iterator_map),))
self.global_step.assign_add(1)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
def all_strategy_combinations():
  """Returns eager-mode test combinations over the supported strategies."""
  supported_strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=supported_strategies, mode="eager")
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the joint multi-task base trainer."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    with distribution.scope():
      mock_tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
      ]
      joint_multitask = multitask.MultiTask(
          tasks=mock_tasks, task_weights={"foo": 1.0, "bar": 1.0})
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=joint_multitask,
          multi_task_model=test_utils.MockMultiTaskModel(),
          optimizer=tf.keras.optimizers.SGD(0.1))
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    # Each task reports its training loss plus its own accuracy metric.
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())

  def test_trainer_with_configs(self):
    foo_routine = configs.TaskRoutine(
        task_name="foo", task_config=test_utils.FooConfig(), task_weight=0.5)
    bar_routine = configs.TaskRoutine(
        task_name="bar", task_config=test_utils.BarConfig(), task_weight=0.5)
    config = configs.MultiTaskConfig(task_routines=(foo_routine, bar_routine))
    config_multitask = multitask.MultiTask.from_config(config)
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=config_multitask,
        multi_task_model=test_utils.MockMultiTaskModel(),
        optimizer=tf.keras.optimizers.SGD(0.1))
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    # Weights from the routines survive the config round-trip.
    self.assertEqual(config_multitask.task_weight("foo"), 0.5)
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)
# Runs the test suite when executed as a script.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple
import dataclasses
from official.core import config_definitions as cfg
from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  """Per-task routine: the task config plus its eval/weight settings."""
  # TODO(hongkuny): deprecate the task_name once we migrated client code.
  task_name: str = ""
  # The task's own config; resolved through the task factory to build a Task.
  task_config: cfg.TaskConfig = None
  # Steps to run eval for this task; None falls back to the caller's default.
  eval_steps: Optional[int] = None
  # Relative task weight, used for loss summation or for task sampling.
  task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  """Config describing all tasks of a multi-task experiment."""
  # Checkpoint to initialize from; "" disables. NOTE(review): the consumer is
  # not visible in this module — confirm against the driver code.
  init_checkpoint: str = ""
  # Optional shared model config; None when the model is configured elsewhere.
  model: hyperparams.Config = None
  # One TaskRoutine per task; task names must be unique.
  task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Config for sampling tasks proportionally to their weights."""
  # Exponent applied to the task weights before normalization.
  alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Config for the annealing task sampling schedule."""
  # Number of train steps that make up one annealing "epoch".
  steps_per_epoch: int = 5
  # Total number of train steps the schedule anneals over.
  total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """OneOf config choosing the task sampling scheme for interleaved training."""
  # Selects the active field below: "uniform", "proportional" or "annealing".
  type: str = ""
  uniform: hyperparams.Config = hyperparams.Config()
  proportional: ProportionalSampleConfig = ProportionalSampleConfig()
  annealing: AnnealingSampleConfig = AnnealingSampleConfig()
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer config for multi-task training."""
  # Trainer flavor; "interleaving" is the default. NOTE(review): other
  # accepted values are not visible in this module — confirm with the
  # trainer factory.
  trainer_type: str = "interleaving"
  # Task sampler used by the interleaving trainer; proportional by default.
  task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation."""
  # All tasks and their routines.
  task: MultiTaskConfig = MultiTaskConfig()
  # Trainer settings, including the task sampling scheme.
  trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
  # Runtime (device/distribution) settings.
  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Attributes:
    eval_tasks: individual evaluation tasks.
  """
  eval_tasks: Tuple[TaskRoutine, ...] = ()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Dict, List, Optional, Union
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Implements the common trainer shared for TensorFlow models."""

  def __init__(
      self,
      eval_tasks: List[base_task.Task],
      model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
      global_step: Optional[tf.Variable] = None,
      eval_steps: Optional[Dict[str, int]] = None,
      checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      eval_tasks: A list of tasks to evaluate.
      model: tf.keras.Model instance.
      global_step: the global step variable.
      eval_steps: a dictionary of steps to run eval keyed by task names.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._tasks = eval_tasks
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
    self._checkpoint_exporter = checkpoint_exporter
    # Models may expose extra trackable objects to include in the checkpoint.
    if hasattr(self.model, "checkpoint_items"):
      checkpoint_items = self.model.checkpoint_items
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        model=self.model,
        global_step=self.global_step,
        **checkpoint_items)
    # Per-task loss/metric objects are built lazily by the properties below.
    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task datasets.
    self.eval_datasets = {}
    self.eval_steps = eval_steps or {}
    for task in self.tasks:
      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.validation_data)

    # Builds per-task validation loops.
    def get_function(task_name, task):

      # Bind the task's metrics, loss metric and (sub-)model once per task.
      task_metrics = self.validation_metrics[task_name]
      task_loss = self.validation_losses[task_name]
      if isinstance(self.model, base_model.MultiTaskBaseModel):
        model = self.model.sub_tasks[task_name]
      else:
        model = self.model

      def step_fn(inputs):
        logs = task.validation_step(inputs, model=model, metrics=task_metrics)
        task_loss.update_state(logs[task.loss])
        return logs

      @tf.function
      def eval_step_fn(iterator):
        distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
        # Unwraps per-replica values so the loop can aggregate plain tensors.
        return tf.nest.map_structure(self.strategy.experimental_local_results,
                                     distributed_outputs)

      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
        task.name: get_function(task.name, task) for task in self.tasks
    }

  @property
  def strategy(self):
    """The `tf.distribute.Strategy` captured at construction time."""
    return self._strategy

  @property
  def tasks(self):
    """The list of tasks to evaluate."""
    return self._tasks

  @property
  def model(self):
    """The model being evaluated."""
    return self._model

  @property
  def global_step(self):
    """The global step variable used for export decisions and reductions."""
    return self._global_step

  @property
  def validation_losses(self):
    """Accesses the validation loss metric object."""
    if self._validation_losses is None:
      # Builds the per-task metrics and losses.
      self._validation_losses = {}
      for task in self.tasks:
        self._validation_losses[task.name] = tf.keras.metrics.Mean(
            "validation_loss", dtype=tf.float32)
    return self._validation_losses

  @property
  def validation_metrics(self):
    """Accesses all validation metric objects."""
    if self._validation_metrics is None:
      # Builds the per-task metrics and losses.
      self._validation_metrics = {}
      for task in self.tasks:
        self._validation_metrics[task.name] = task.build_metrics(training=False)
    return self._validation_metrics

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  def evaluate(self, num_steps: tf.Tensor):
    """Performs evaluation for each `EvalTask`."""
    # Resets all loss and metric states before the eval pass.
    for metric in self.validation_losses.values():
      metric.reset_states()
    for metrics in self.validation_metrics.values():
      for metric in metrics:
        metric.reset_states()
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

    for task in self.tasks:
      outputs = None
      name = task.name
      eval_iter = eval_iters[name]
      # Per-task override of the step budget; falls back to `num_steps`.
      task_eval_steps = self.eval_steps.get(name, None) or num_steps
      outputs = self.task_fns[name](
          eval_iter,
          task_eval_steps,
          state=outputs,
          reduce_fn=task.aggregate_logs)
      task_metrics = self.validation_metrics[name]
      task_loss = self.validation_losses[name]
      logs = {}
      for metric in task_metrics + [task_loss]:
        logs[metric.name] = metric.result()
      if outputs:
        metrics = task.reduce_aggregated_logs(
            outputs, global_step=self.global_step)
        logs.update(metrics)
      results[name] = logs

    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, results, self.global_step.numpy())
    return results
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
def all_strategy_combinations():
  """Eager-mode combinations across default, TPU and one-GPU strategies."""
  return combinations.combine(
      mode="eager",
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
  )
class MockModel(tf.keras.Model):
  """Toy model whose added loss encodes which features it saw.

  Adds a loss of 0.0 when the input dict contains a "y" key and 1.0
  otherwise, so tests can tell the two mock tasks apart via their
  validation losses.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    """Returns `dense(inputs["x"])` and records a task-identifying loss."""
    # The loss value distinguishes the tasks: 0.0 iff "y" is present.
    # (A leftover debug `print` was removed here; it logged on every trace.)
    if "y" in inputs:
      self.add_loss(tf.zeros((1,), dtype=tf.float32))
    else:
      self.add_loss(tf.ones((1,), dtype=tf.float32))
    return self.dense(inputs["x"])
class MockTask(base_task.Task):
  """Mock task object for testing."""

  def build_metrics(self, training: bool = True):
    """Returns a single accuracy metric regardless of `training`."""
    del training
    return [tf.keras.metrics.Accuracy(name="acc")]

  def build_inputs(self, params):
    """Builds an infinite, constant-valued dataset for this task."""

    def generate_data(_):
      features = tf.zeros(shape=(2,), dtype=tf.float32)
      label = tf.zeros([1], dtype=tf.int32)
      # The "bar" task carries an extra "y" feature alongside "x".
      if self.name == "bar":
        return dict(x=features, y=features), label
      return dict(x=features), label

    return (
        tf.data.Dataset.range(1)
        .repeat()
        .map(generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .prefetch(buffer_size=1)
        .batch(2, drop_remainder=True))

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs the base validation step and adds a constant "counter" log."""
    logs = super().validation_step(inputs, model, metrics)
    logs["counter"] = tf.ones((1,), dtype=tf.float32)
    return logs

  def aggregate_logs(self, state, step_outputs):
    """Accumulates per-replica step outputs into the running `state` dict."""
    if state is None:
      state = {}
    for key, per_replica_values in step_outputs.items():
      stacked = np.concatenate(
          [np.expand_dims(v.numpy(), axis=0) for v in per_replica_values])
      state.setdefault(key, []).append(stacked)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Collapses each accumulated list of arrays into a scalar sum, in place."""
    for key in aggregated_logs:
      aggregated_logs[key] = np.sum(np.stack(aggregated_logs[key], axis=0))
    return aggregated_logs
class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the multi-task evaluator."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator(self, distribution):
    with distribution.scope():
      eval_tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo"),
      ]
      test_evaluator = evaluator.MultiTaskEvaluator(
          eval_tasks=eval_tasks, model=MockModel())
    results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    # Both tasks report their loss and the accuracy metric.
    for task_name in ("bar", "foo"):
      self.assertContainsSubset(["validation_loss", "acc"],
                                results[task_name].keys())
    # MockModel adds zero loss when "y" is present ("bar") and one otherwise.
    self.assertEqual(results["bar"]["validation_loss"], 0.0)
    self.assertEqual(results["foo"]["validation_loss"], 1.0)

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator_numpy_metrics(self, distribution):
    with distribution.scope():
      eval_tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo"),
      ]
      test_evaluator = evaluator.MultiTaskEvaluator(
          eval_tasks=eval_tasks, model=MockModel())
    results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    # Each of the 5 eval steps adds one "counter" per replica.
    expected_count = 5. * distribution.num_replicas_in_sync
    self.assertEqual(results["bar"]["counter"], expected_count)
    self.assertEqual(results["foo"]["counter"], expected_count)
# Runs the test suite when executed as a script.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler as sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that interleaves task update.

  Unlike the joint base trainer, each train step samples exactly one task
  according to `task_sampler` and runs only that task's own train step.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    """Initializes the trainer.

    Args:
      multi_task: a `multitask.MultiTask` instance holding all tasks.
      multi_task_model: the model to train, either a plain `tf.keras.Model`
        or a `base_model.MultiTaskBaseModel`.
      optimizer: the optimizer shared by all tasks.
      task_sampler: a `sampler.TaskSampler` providing the per-step task
        sampling distribution.
      trainer_options: optional `orbit.StandardTrainerOptions`.
    """
    super().__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build per task train step.
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        # Select the per-task sub-model when a MultiTaskBaseModel is used.
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
    # on TensorBoard.
    self._task_step_counters = {
        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
    }

  def task_step_counter(self, name):
    """Returns the per-task step counter variable for task `name`."""
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    """Samples one task from the sampler's distribution and trains it.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """
    # Sample one task to train according to a multinomial distribution.
    # The seed is tied to the global step, so the sampled task sequence is
    # deterministic for a given run.
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
        self.global_step)
    # Prepend a [0.0] for indexing convenience.
    cumulative_sample_distribution = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
        axis=0)
    # The task whose [begin, end) interval contains `rn` trains this step.
    for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
      begin = cumulative_sample_distribution[idx]
      end = cumulative_sample_distribution[idx + 1]
      if rn >= begin and rn < end:
        self._strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)

  def train_loop_end(self):
    """Record loss and metric values per task."""
    result = super().train_loop_end()

    # Interleaving training does not have a good semantic for `total_loss`. In
    # fact, it is always zero. To avoid confusion, we filter the `total_loss`
    # from the result logs.
    if 'total_loss' in result:
      result.pop('total_loss')
    return result
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import configs
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
from official.modeling.multitask import test_utils
def all_strategy_combinations():
  """Returns the strategy/mode grid shared by every test in this file."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the interleaving multi-task trainer."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    """Trains two mock tasks and checks the per-task logs."""
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      sampler = task_sampler.UniformTaskSampler(
          task_weights=test_multitask.task_weights)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    # The interleaving trainer filters out the meaningless `total_loss`.
    self.assertNotIn("total_loss", results)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    """Builds the trainer from configs and checks sampled step counts."""
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=3.0),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=1.0)))
    with distribution.scope():
      test_multitask = multitask.MultiTask.from_config(config)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      num_step = 1000
      sampler = task_sampler.AnnealingTaskSampler(
          task_weights=test_multitask.task_weights,
          # Integer division: `steps_per_epoch` is declared as an int.
          steps_per_epoch=num_step // 5,
          total_steps=num_step)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
    results = test_trainer.train(tf.convert_to_tensor(num_step, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_trainer.global_step.numpy(), num_step)
    # Exactly one task is sampled per step, so the counters sum to num_step.
    bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
    foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
    self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
# Runs the test suite when executed as a script.
if __name__ == "__main__":
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import abc
from typing import Dict, List, Optional, Text, Union
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import configs
# Short aliases for frequently used configuration types.
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
class MultiTask(tf.Module, metaclass=abc.ABCMeta):
  """A multi-task class to manage multiple tasks."""

  def __init__(self,
               tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
               task_weights: Optional[Dict[str, Union[float, int]]] = None,
               task_eval_steps: Optional[Dict[str, int]] = None,
               name: Optional[str] = None):
    """MultiTask initialization.

    Args:
      tasks: a list or a flat dict of Task.
      task_weights: a dict of (task, task weight), task weight can be applied
        directly during loss summation in a joint backward step, or it can be
        used to sample task among interleaved backward step.
      task_eval_steps: a dict of (task, eval steps).
      name: the instance name of a MultiTask object.
    """
    super().__init__(name=name)
    if isinstance(tasks, list):
      self._tasks = {}
      for task in tasks:
        # Task names key the internal dict, so they must be unique.
        if task.name in self._tasks:
          raise ValueError("Duplicated tasks found, task.name is %s" %
                           task.name)
        self._tasks[task.name] = task
    elif isinstance(tasks, dict):
      self._tasks = tasks
    else:
      raise ValueError("The tasks argument has an invalid type: %s" %
                       type(tasks))
    self.task_eval_steps = task_eval_steps or {}
    # Tasks without an explicit weight default to 1.0.
    self._task_weights = task_weights or {}
    self._task_weights = dict([
        (name, self._task_weights.get(name, 1.0)) for name in self.tasks
    ])

  @classmethod
  def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
    """Builds a MultiTask from a `configs.MultiTaskConfig`."""
    tasks = {}
    task_eval_steps = {}
    task_weights = {}
    for task_routine in config.task_routines:
      task_name = task_routine.task_name or task_routine.task_config.name
      tasks[task_name] = task_factory.get_task(
          task_routine.task_config, logging_dir=logging_dir, name=task_name)
      task_eval_steps[task_name] = task_routine.eval_steps
      task_weights[task_name] = task_routine.task_weight
    return cls(
        tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)

  @property
  def tasks(self):
    """The dict of task name -> `base_task.Task`."""
    return self._tasks

  def task_weight(self, task_name):
    """Returns the weight of a single task."""
    return self._task_weights[task_name]

  @property
  def task_weights(self):
    """The dict of task name -> task weight."""
    return self._task_weights

  @classmethod
  def create_optimizer(cls,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates an optimizer by delegating to `base_task.Task`."""
    return base_task.Task.create_optimizer(
        optimizer_config=optimizer_config, runtime_config=runtime_config)

  def joint_train_step(self, task_inputs,
                       multi_task_model: base_model.MultiTaskBaseModel,
                       optimizer: tf.keras.optimizers.Optimizer, task_metrics,
                       **kwargs):
    """The joint train step.

    Args:
      task_inputs: a dictionary of task names and per-task features.
      multi_task_model: a MultiTaskBaseModel instance.
      optimizer: a tf.optimizers.Optimizer.
      task_metrics: a dictionary of task names and per-task metrics.
      **kwargs: other arguments to pass through.

    Returns:
      A dictionary of losses, including per-task losses and their weighted
      sum (`total_loss`).
    """
    losses = {}
    with tf.GradientTape() as tape:
      total_loss = 0.0
      for name, model in multi_task_model.sub_tasks.items():
        inputs = task_inputs[name]
        if isinstance(inputs, tuple) and len(inputs) == 2:
          features, labels = inputs
        elif isinstance(inputs, dict):
          # A dict batch serves as both features and labels.
          features, labels = inputs, inputs
        else:
          raise ValueError("The iterator output is neither a tuple nor a "
                           "dictionary. It is not implemented to support "
                           "such outputs.")
        outputs = model(features, training=True)
        task_loss = self.tasks[name].build_losses(labels, outputs)
        task_weight = self.task_weight(name)
        total_loss += task_weight * task_loss
        losses[name] = task_loss
        self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
                                         **kwargs)

      # Scales loss as the default gradients allreduce performs sum inside
      # the optimizer.
      scaled_loss = total_loss / tf.distribute.get_strategy(
      ).num_replicas_in_sync
    tvars = multi_task_model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    losses["total_loss"] = total_loss
    return losses
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import abc
from typing import Union, Dict, Text
import tensorflow as tf
from official.modeling.multitask import configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Stores the mapping of task name -> sampling weight."""
    self._task_weights = task_weights

  @property
  def task_weights(self):
    """The dict of task name -> task weight."""
    return self._task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which to be sampled against.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float tensor with shape (#(task), 1) that represents the cumulative
      sampling distribution.
    """
    pass
class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Precomputes the fixed uniform cumulative distribution."""
    super(UniformTaskSampler, self).__init__(task_weights=task_weights)
    num_tasks = len(self._task_weights)
    # Uniform multinomial: every task gets probability 1/num_tasks.
    uniform_probs = tf.constant(
        [1.0 / num_tasks] * num_tasks, dtype=tf.float32)
    self._uniform_cumulative = tf.math.cumsum(uniform_probs)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the fixed cumulative distribution; `global_step` is unused."""
    del global_step
    return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights.

  Each task's unnormalized probability is `weight ** alpha`: `alpha=1.0`
  samples exactly proportionally to the weights, while `alpha=0.0` reduces
  to uniform sampling.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    """Precomputes the fixed cumulative sampling distribution.

    Args:
      task_weights: a dict of task name -> sampling weight.
      alpha: exponent applied to the task weights before normalization.
    """
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    # Weight order follows the (insertion-ordered) dict so the cumulative
    # distribution lines up with the trainer's task iteration order.
    task_weight_dict_ordered_list = tf.constant(
        [weight for _, weight in self._task_weights.items()], dtype=tf.float32)
    task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
    task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    # Fixed (step-independent) distribution. Attribute name fixed from the
    # earlier misspelling `_porportional_cumulative`.
    self._proportional_cumulative = tf.math.cumsum(task_distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the fixed cumulative distribution; `global_step` is unused."""
    del global_step
    return self._proportional_cumulative
class AnnealingTaskSampler(TaskSampler):
  """Samples tasks by task weights annealed with training progress.

  See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    cur_epoch = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    # Exponent anneals linearly with epoch; the 1e-10 guards against a
    # division by zero when there is a single epoch.
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    ordered_weights = tf.constant(
        list(self._task_weights.values()), dtype=tf.float32)
    task_sizes = tf.math.pow(ordered_weights, tf.cast(alpha, dtype=tf.float32))
    return tf.math.cumsum(task_sizes / tf.reduce_sum(task_sizes))
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Creates a task sampler from configuration and task weights.

  Args:
    config: A TaskSamplingConfig oneof selecting the sampler type plus its
      type-specific parameters.
    task_weights: Mapping from task name to sampling weight.

  Returns:
    A `TaskSampler` instance matching `config.type`.

  Raises:
    RuntimeError: If `config.type` is not 'uniform', 'proportional' or
      'annealing'.
  """
  oneof_config = config.get()
  if config.type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  elif config.type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  elif config.type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  else:
    # Name the offending type so misconfiguration is easy to diagnose.
    raise RuntimeError('Task sampler type not supported: %s' % config.type)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import tensorflow as tf
from official.modeling.multitask import configs
from official.modeling.multitask import task_sampler as sampler
class TaskSamplerTest(tf.test.TestCase):
  """Tests the three task sampler variants against known distributions."""

  def setUp(self):
    super(TaskSamplerTest, self).setUp()
    self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}

  def test_uniform_sample_distribution(self):
    uniform_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    # The uniform CDF must be identical at every step.
    for step in range(5):
      cumulative = uniform_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0], cumulative.numpy())

  def test_proportional_sample_distribution(self):
    prop_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # CucmulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
    for step in range(5):
      cumulative = prop_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0], cumulative.numpy())

  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    anneal_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=step_per_epoch,
                total_steps=step_per_epoch * num_epoch)), self._task_weights)
    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    # Expected per-epoch CDFs as the exponent anneals each epoch.
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      for _ in range(step_per_epoch):
        cumulative = anneal_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative.numpy())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utils for mock models and tasks."""
from typing import Dict, Text
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.multitask import base_model
class MockFooModel(tf.keras.Model):
  """A mock model can consume 'foo' and 'bar' inputs."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._share_layer = shared_layer
    self._foo_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # Emit a fixed zero auxiliary loss so loss aggregation can be exercised.
    self.add_loss(tf.zeros((1,), dtype=tf.float32))
    input_tensor = inputs["foo"] if "foo" in inputs else inputs["bar"]
    return self._foo_specific_layer(self._share_layer(input_tensor))
class MockBarModel(tf.keras.Model):
  """A mock model that consumes only the 'bar' input."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._share_layer = shared_layer
    self._bar_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # Fixed zero auxiliary loss, distinguishable from MockFooModel's by shape.
    self.add_loss(tf.zeros((2,), dtype=tf.float32))
    shared_out = self._share_layer(inputs["bar"])
    return self._bar_specific_layer(shared_out)
class MockMultiTaskModel(base_model.MultiTaskBaseModel):
  """Multi-task model wiring the mock foo/bar models through one shared layer."""

  def __init__(self, *args, **kwargs):
    # Created before super().__init__() so it already exists when
    # _instantiate_sub_tasks runs (presumably from the base constructor —
    # verify against MultiTaskBaseModel).
    self._shared_dense = tf.keras.layers.Dense(1)
    super().__init__(*args, **kwargs)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    return {
        "foo": MockFooModel(self._shared_dense),
        "bar": MockBarModel(self._shared_dense),
    }
def mock_data(feature_name):
  """Mock dataset function."""

  def _generate_data(_):
    # One zero feature vector under `feature_name` plus a zero int label.
    features = {feature_name: tf.zeros(shape=(2,), dtype=tf.float32)}
    label = tf.zeros([1], dtype=tf.int32)
    return features, label

  return (
      tf.data.Dataset.range(1)
      .repeat()
      .map(_generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .prefetch(buffer_size=1)
      .batch(2, drop_remainder=True))
class FooConfig(cfg.TaskConfig):
  """Marker task config registered to construct `MockFooTask`."""
  pass
class BarConfig(cfg.TaskConfig):
  """Marker task config registered to construct `MockBarTask`."""
  pass
@task_factory.register_task_cls(FooConfig)
class MockFooTask(base_task.Task):
  """Mock foo task object for testing."""

  def build_metrics(self, training: bool = True):
    del training  # Same metric set for train and eval.
    return [tf.keras.metrics.Accuracy(name="foo_acc")]

  def build_inputs(self, params):
    return mock_data("foo")

  def build_model(self) -> tf.keras.Model:
    return MockFooModel(shared_layer=tf.keras.layers.Dense(1))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    # MSE plus any model-emitted auxiliary losses.
    mse = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      mse += tf.add_n(aux_losses)
    return tf.reduce_mean(mse)
@task_factory.register_task_cls(BarConfig)
class MockBarTask(base_task.Task):
  """Mock bar task object for testing."""

  def build_metrics(self, training: bool = True):
    del training  # Same metric set for train and eval.
    return [tf.keras.metrics.Accuracy(name="bar_acc")]

  def build_inputs(self, params):
    return mock_data("bar")

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    # MSE plus any model-emitted auxiliary losses.
    mse = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      mse += tf.add_n(aux_losses)
    return tf.reduce_mean(mse)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, List, Optional, Tuple
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as core_lib
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
# Maps the `trainer_type` config value to the trainer class implementing it.
TRAINERS = {
    'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
    'joint': base_trainer.MultiTaskBaseTrainer
}
def run_experiment(
    *,
    distribution_strategy: tf.distribute.Strategy,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel,
    mode: str,
    params: configs.MultiTaskExperimentConfig,
    model_dir: str,
    trainer: Optional[base_trainer.MultiTaskBaseTrainer] = None
) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution distribution_strategy.
    task: A MultiTaskTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    trainer: (optional) A multi-task trainer to use. If none is provided, a
      default one will be created based on `params`.

  Returns:
    model: `base_model.MultiTaskBaseModel` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      # The interleaving trainer additionally needs a task sampler to pick
      # which task to step on.
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    if trainer is None:
      # Only build a trainer when actually training; eval-only modes run
      # without one.
      trainer = TRAINERS[params.trainer.trainer_type](
          **kwargs) if is_training else None
    if is_eval:
      eval_steps = task.task_eval_steps
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=task.tasks.values(),
          model=model,
          eval_steps=eval_steps,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None
  # Checkpoint/step bookkeeping comes from the trainer when training,
  # otherwise from the evaluator (at least one exists for every valid mode).
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step
  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)
  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)
  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':
      def timeout_fn():
        # Stop continuous eval once the evaluated step reaches the end of
        # training.
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False
      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  return model
def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: List[base_task.Task],
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
  """Runs train/eval configured by the experiment params.

  Trains a single task while evaluating on multiple tasks.

  Args:
    distribution_strategy: A distribution distribution_strategy.
    train_task: A base_task.Task instance.
    eval_tasks: A list of evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    trainer: the core_lib.Trainer instance. It should be created within the
      strategy.scope(). If not provided, an instance will be created by default
      if `mode` contains 'train'.

  Returns:
    model: `tf.keras.Model` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    if is_training:
      trainer = trainer or core_lib.Trainer(
          config=params,
          task=train_task,
          model=train_task.build_model(),
          optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
                                                params.runtime),
          train=True,
          evaluate=False)
    else:
      trainer = None
    # Reuse the trainer's model when training; otherwise build a fresh one
    # for evaluation.
    model = trainer.model if trainer else train_task.build_model()
    if is_eval:
      # Per-task eval step budgets keyed by task name from the config.
      eval_steps = dict([(task_routine.task_config.name,
                          task_routine.eval_steps)
                         for task_routine in params.eval_tasks])
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None,
          eval_steps=eval_steps,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None
  # Checkpoint/step bookkeeping comes from the trainer when training,
  # otherwise from the evaluator.
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)
  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(model_dir, 'validation') if
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
      (save_summary) else None)
  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':
      def timeout_fn():
        # Stop continuous eval once the evaluated step reaches the end of
        # training.
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False
      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  if run_post_eval:
    return model, evaluator.evaluate(
        tf.convert_to_tensor(params.trainer.validation_steps))  # pytype: disable=bad-return-type  # typed-keras
  else:
    return model, {}  # pytype: disable=bad-return-type  # typed-keras
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.train_lib."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import task_factory
from official.modeling.hyperparams import params_dict
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
from official.modeling.multitask import train_lib
class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end smoke tests for the multitask train_lib drivers."""

  def setUp(self):
    super().setUp()
    # Minimal trainer config shared by all test cases; overrides the
    # experiment configs built in each test via override_params_dict.
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, flag_mode):
    # Exercises run_experiment with a two-task (foo/bar) multitask config.
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
                    task_name='foo', task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=test_multitask,
        model=model,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    # Exercises run_experiment_with_multitask_eval: single train task with
    # two evaluation tasks at distinct eval-step budgets.
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=(configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
                    configs.TaskRoutine(
                        task_name='bar',
                        task_config=test_utils.BarConfig(),
                        eval_steps=3)))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = [
          task_factory.get_task(config.task_config, name=config.task_name)
          for config in experiment_config.eval_tasks
      ]
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=distribution_strategy,
        train_task=train_task,
        eval_tasks=eval_tasks,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization package definition."""
# pylint: disable=wildcard-import
from official.modeling.optimization.configs.learning_rate_config import *
from official.modeling.optimization.configs.optimization_config import *
from official.modeling.optimization.configs.optimizer_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.modeling.optimization.lr_schedule import *
from official.modeling.optimization.optimizer_factory import OptimizerFactory
from official.modeling.optimization.optimizer_factory import register_optimizer_cls
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adafactor optimizer.
A new optimizer that will be open sourced soon.
"""
# pylint: disable=invalid-name
# Placeholder sentinel standing in for the not-yet-released Adafactor
# optimizer class (see the module docstring above).
Adafactor = "Unimplemented"