Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -226,9 +226,13 @@ class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
if mixed_precision_dtype != 'float16':
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
self.assertIsInstance(
trainer.optimizer,
(tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,10 +18,10 @@ from absl import logging
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking
from tensorflow.python.trackable import autotrackable
class VolatileTrackable(tracking.AutoTrackable):
class VolatileTrackable(autotrackable.AutoTrackable):
"""A util class to keep Trackables that might change instances."""
def __init__(self, **kwargs):
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -41,15 +41,15 @@ _PARAM_RE = re.compile(
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# Yaml loader with an implicit resolver to parse float decimal and exponential
# Yaml LOADER with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
_LOADER = yaml.SafeLoader
_LOADER.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
......@@ -288,42 +288,42 @@ class ParamsDict(object):
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v != right_v:
raise KeyError(
'Found inconsistncy between key `{}` and key `{}`.'.format(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '!=' in restriction:
tokens = restriction.split('!=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v == right_v:
raise KeyError(
'Found inconsistncy between key `{}` and key `{}`.'.format(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<' in restriction:
tokens = restriction.split('<')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v >= right_v:
raise KeyError(
'Found inconsistncy between key `{}` and key `{}`.'.format(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '<=' in restriction:
tokens = restriction.split('<=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v > right_v:
raise KeyError(
'Found inconsistncy between key `{}` and key `{}`.'.format(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>' in restriction:
tokens = restriction.split('>')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v <= right_v:
raise KeyError(
'Found inconsistncy between key `{}` and key `{}`.'.format(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
elif '>=' in restriction:
tokens = restriction.split('>=')
_, left_v, _, right_v = _get_kvs(tokens, params_dict)
if left_v < right_v:
raise KeyError(
'Found inconsistncy between key `{}` and key `{}`.'.format(
'Found inconsistency between key `{}` and key `{}`.'.format(
tokens[0], tokens[1]))
else:
raise ValueError('Unsupported relation in restriction.')
......@@ -332,7 +332,7 @@ class ParamsDict(object):
def read_yaml_to_params_dict(file_path: str):
"""Reads a YAML file to a ParamsDict."""
with tf.io.gfile.GFile(file_path, 'r') as f:
params_dict = yaml.load(f, Loader=LOADER)
params_dict = yaml.load(f, Loader=_LOADER)
return ParamsDict(params_dict)
......@@ -453,7 +453,7 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
except ValueError:
pass
params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=LOADER)
params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=_LOADER)
if isinstance(params_dict, dict):
params.override(params_dict, is_strict)
else:
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -43,3 +43,12 @@ class MultiTaskBaseModel(tf.Module):
def initialize(self):
"""Optional function that loads a pre-train checkpoint."""
return
def build(self):
"""Builds the networks for tasks to make sure variables are created."""
# Try to build all sub tasks.
for task_model in self._sub_tasks.values():
# Assumes all the tf.Module models are built because we don't have any
# way to check them.
if isinstance(task_model, tf.keras.Model) and not task_model.built:
_ = task_model(task_model.inputs)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment