Commit 472e2f80 authored by zhanggzh

Merge remote-tracking branch 'tf_model/main'

parents d91296eb f3a14f85
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom checkpoint manager that also exports saved models."""
import os
import re
import time
from typing import Callable, List, Mapping, Optional, Union
from absl import logging
import tensorflow as tf
SAVED_MODULES_PATH_SUFFIX = 'saved_modules'
def make_saved_modules_directory_name(checkpoint_name: str) -> str:
return f'{checkpoint_name}_{SAVED_MODULES_PATH_SUFFIX}'
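# For example (illustrative): make_saved_modules_directory_name('/tmp/ckpt-5')
# returns '/tmp/ckpt-5_saved_modules'.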
class SavedModelCheckpointManager(tf.train.CheckpointManager):
"""A CheckpointManager that also exports `SavedModel`s."""
def __init__(self,
checkpoint: tf.train.Checkpoint,
directory: str,
max_to_keep: int,
modules_to_export: Optional[Mapping[str, tf.Module]] = None,
keep_checkpoint_every_n_hours: Optional[int] = None,
checkpoint_name: str = 'ckpt',
step_counter: Optional[tf.Variable] = None,
checkpoint_interval: Optional[int] = None,
init_fn: Optional[Callable[[], None]] = None):
"""See base class."""
super().__init__(
checkpoint=checkpoint,
directory=directory,
max_to_keep=max_to_keep,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
checkpoint_name=checkpoint_name,
step_counter=step_counter,
checkpoint_interval=checkpoint_interval,
init_fn=init_fn)
self._modules_to_export = modules_to_export
self._savedmodels = self.get_existing_savedmodels()
def save(self,
checkpoint_number: Optional[int] = None,
check_interval: bool = True,
options: Optional[tf.train.CheckpointOptions] = None):
"""See base class."""
checkpoint_path = super().save(
checkpoint_number=checkpoint_number,
check_interval=check_interval,
options=options)
if not checkpoint_path: # Nothing got written.
return
if not self._modules_to_export: # No modules to export.
logging.info('Skip saving SavedModel due to empty modules_to_export.')
return checkpoint_path
# Save the models for the checkpoint that just got written.
saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
for model_name, model in self._modules_to_export.items():
signatures = getattr(model, 'saved_model_signatures', None)
if signatures is not None:
tf.saved_model.save(
obj=model,
export_dir=os.path.join(saved_modules_directory, model_name),
signatures=signatures)
saved_modules_directories_to_keep = [
make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
]
existing_saved_modules_dirs = self.get_existing_savedmodels()
self._savedmodels = []
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
if saved_modules_dir_to_keep in existing_saved_modules_dirs:
self._savedmodels.append(saved_modules_dir_to_keep)
for existing_saved_modules_dir in existing_saved_modules_dirs:
if existing_saved_modules_dir not in self._savedmodels:
tf.io.gfile.rmtree(existing_saved_modules_dir)
return checkpoint_path
def get_existing_savedmodels(self) -> List[str]:
"""Gets a list of all existing SavedModel paths in `directory`.
Returns:
A list of all existing SavedModel paths.
"""
saved_modules_glob = make_saved_modules_directory_name(
self._checkpoint_prefix + '-*')
return tf.io.gfile.glob(saved_modules_glob)
@property
def latest_savedmodel(self) -> Union[str, None]:
"""The path of the most recent SavedModel in `directory`.
Returns:
The latest SavedModel path. If there are no SavedModels, returns `None`.
"""
if self._savedmodels:
return self._savedmodels[-1]
return None
@property
def savedmodels(self) -> List[str]:
"""A list of managed SavedModels.
Returns:
A list of SavedModel paths, sorted from oldest to newest.
"""
return self._savedmodels
@property
def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
return self._modules_to_export
def get_savedmodel_number_from_path(self,
savedmodel_path: str) -> Union[int, None]:
"""Gets the savedmodel_number/checkpoint_number from savedmodel filepath.
    The savedmodel_number is the global step when used with the orbit
    controller.
Args:
savedmodel_path: savedmodel directory path.
Returns:
Savedmodel number or None if no matched pattern found in savedmodel path.
"""
pattern = rf'\d+_{SAVED_MODULES_PATH_SUFFIX}$'
savedmodel_number = re.search(pattern, savedmodel_path)
if savedmodel_number:
savedmodel_number = savedmodel_number.group()
return int(savedmodel_number[:-len(SAVED_MODULES_PATH_SUFFIX) - 1])
return None
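  # A worked example (illustrative, not in the original source): a path like
  # '/tmp/ckpt-42_saved_modules' yields 42, while a path without the
  # '<number>_saved_modules' suffix yields None.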
def savedmodels_iterator(self,
min_interval_secs: float = 0,
timeout: Optional[float] = None,
timeout_fn: Optional[Callable[[], bool]] = None):
"""Continuously yield new SavedModel files as they appear.
    The iterator only checks for new savedmodels when control flow has been
    returned to it. The logic is the same as `tf.train.checkpoints_iterator`.
Args:
min_interval_secs: The minimum number of seconds between yielding
savedmodels.
timeout: The maximum number of seconds to wait between savedmodels. If
left as `None`, then the process will wait indefinitely.
timeout_fn: Optional function to call after a timeout. If the function
returns True, then it means that no new savedmodels will be generated
and the iterator will exit. The function is called with no arguments.
Yields:
String paths to latest SavedModel files as they arrive.
"""
savedmodel_path = None
while True:
new_savedmodel_path = self.wait_for_new_savedmodel(
savedmodel_path, timeout=timeout)
if new_savedmodel_path is None:
if not timeout_fn:
# timed out
logging.info('Timed-out waiting for a savedmodel.')
return
if timeout_fn():
# The timeout_fn indicated that we are truly done.
return
else:
# The timeout_fn indicated that more savedmodels may come.
continue
start = time.time()
savedmodel_path = new_savedmodel_path
yield savedmodel_path
time_to_next_eval = start + min_interval_secs - time.time()
if time_to_next_eval > 0:
time.sleep(time_to_next_eval)
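  # Illustrative use of the iterator above (not part of the original class):
  # a separate evaluation job can consume SavedModels as a concurrently
  # running trainer produces them, e.g.
  #
  #   for savedmodel_path in manager.savedmodels_iterator(timeout=3600):
  #     evaluate(savedmodel_path)  # `evaluate` is a hypothetical helper.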
def wait_for_new_savedmodel(
self,
last_savedmodel: Optional[str] = None,
seconds_to_sleep: float = 1.0,
timeout: Optional[float] = None) -> Union[str, None]:
"""Waits until a new savedmodel file is found.
Args:
last_savedmodel: The last savedmodel path used or `None` if we're
expecting a savedmodel for the first time.
seconds_to_sleep: The number of seconds to sleep for before looking for a
new savedmodel.
timeout: The maximum number of seconds to wait. If left as `None`, then
the process will wait indefinitely.
Returns:
A new savedmodel path, or None if the timeout was reached.
"""
logging.info('Waiting for new savedmodel at %s', self._directory)
stop_time = time.time() + timeout if timeout is not None else None
last_savedmodel_number = 0
if last_savedmodel:
last_savedmodel_number = self.get_savedmodel_number_from_path(
last_savedmodel)
while True:
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
return None
existing_savedmodels = {}
for savedmodel_path in self.get_existing_savedmodels():
savedmodel_number = self.get_savedmodel_number_from_path(
savedmodel_path)
if savedmodel_number is not None:
existing_savedmodels[savedmodel_number] = savedmodel_path
      # Find the first savedmodel with a larger step number to use as the
      # next savedmodel.
savedmodel_path = None
existing_savedmodels = dict(sorted(existing_savedmodels.items()))
for savedmodel_number in existing_savedmodels:
if savedmodel_number > last_savedmodel_number:
savedmodel_path = existing_savedmodels[savedmodel_number]
break
if savedmodel_path:
logging.info('Found new savedmodel at %s', savedmodel_path)
return savedmodel_path
else:
time.sleep(seconds_to_sleep)
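# A minimal end-to-end sketch (not part of the original module), mirroring the
# unit test below: a tf.Module exposing `saved_model_signatures` is exported
# alongside each checkpoint. All names here are illustrative only.
def _example_savedmodel_manager_usage(directory: str) -> Union[str, None]:

  class _Doubler(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([None, 4])])
    def call(self, inputs):
      return inputs * 2.0

    @property
    def saved_model_signatures(self):
      return dict(serving_default=self.call)

  manager = SavedModelCheckpointManager(
      checkpoint=tf.train.Checkpoint(),
      directory=directory,
      max_to_keep=2,
      modules_to_export={'doubler': _Doubler()})
  # save() writes the checkpoint and exports the SavedModel side by side.
  manager.save(checkpoint_number=1)
  return manager.latest_savedmodel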
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import Iterable
import tensorflow as tf
from official.core import savedmodel_checkpoint_manager
def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
for model_name in models:
if not tf.io.gfile.isdir(
os.path.join(
savedmodel_checkpoint_manager.make_saved_modules_directory_name(
checkpoint_path), model_name)):
return False
return True
class _ModelForTest(tf.keras.Model):
def __init__(self, hidden_size: int = 8):
super().__init__()
self.dense = tf.keras.layers.Dense(hidden_size)
@tf.function(input_signature=[tf.TensorSpec([None, 16])])
def call(self, inputs):
return self.dense(inputs)
@property
def saved_model_signatures(self):
# Build SavedModel signatures.
return dict(serving_default=self.call)
class CheckpointManagerTest(tf.test.TestCase):
def _create_manager(self, max_to_keep: int = 1) -> tf.train.CheckpointManager:
"""Sets up SavedModelCheckpointManager object.
Args:
max_to_keep: max number of savedmodels to keep.
Returns:
created savedmodel manager.
"""
models = {
'model_1': _ModelForTest(12),
'model_2': _ModelForTest(14),
}
checkpoint = tf.train.Checkpoint()
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
checkpoint=checkpoint,
directory=self.get_temp_dir(),
max_to_keep=max_to_keep,
modules_to_export=models)
return manager
def test_max_to_keep(self):
manager = self._create_manager()
models = manager.modules_to_export
first_path = manager.save()
second_path = manager.save()
savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
manager.latest_checkpoint)
self.assertEqual(savedmodel, manager.latest_savedmodel)
self.assertTrue(_models_exist(second_path, models.keys()))
self.assertFalse(_models_exist(first_path, models.keys()))
def test_returns_none_after_timeout(self):
manager = self._create_manager()
start = time.time()
ret = manager.wait_for_new_savedmodel(
None, timeout=1.0, seconds_to_sleep=0.5)
end = time.time()
self.assertIsNone(ret)
    # We've waited at least 0.5 seconds.
self.assertGreater(end, start + 0.5)
# The timeout kicked in.
self.assertLess(end, start + 0.6)
def test_saved_model_iterator(self):
manager = self._create_manager(max_to_keep=2)
self.assertIsNotNone(manager.save(checkpoint_number=1))
self.assertIsNotNone(manager.save(checkpoint_number=2))
self.assertIsNotNone(manager.save(checkpoint_number=3))
# Savedmodels are in time order.
expected_savedmodels = manager.savedmodels
# Order not guaranteed.
existing_savedmodels = manager.get_existing_savedmodels()
savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
self.assertEqual(savedmodels, expected_savedmodels)
self.assertEqual(set(savedmodels), set(existing_savedmodels))
def test_saved_model_iterator_timeout_fn(self):
manager = self._create_manager()
timeout_fn_calls = [0]
def timeout_fn():
timeout_fn_calls[0] += 1
return timeout_fn_calls[0] > 3
results = list(
manager.savedmodels_iterator(timeout=0.1, timeout_fn=timeout_fn))
self.assertEqual([], results)
self.assertEqual(4, timeout_fn_calls[0])
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A global factory to register and access all registered tasks."""
from official.core import registry
_REGISTERED_TASK_CLS = {}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
  Besides a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config."""
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
if task_config.BUILDER is not None:
return task_config.BUILDER(task_config, **kwargs)
return get_task_cls(task_config.__class__)(task_config, **kwargs)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for testing."""
import tensorflow as tf
class FakeKerasModel(tf.keras.Model):
"""Fake keras model for testing."""
def __init__(self):
super().__init__()
self.dense = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
return self.dense2(self.dense(inputs))
class _Dense(tf.Module):
"""A dense layer."""
def __init__(self, input_dim, output_size, name=None):
super().__init__(name=name)
with self.name_scope:
self.w = tf.Variable(
tf.random.normal([input_dim, output_size]), name='w')
self.b = tf.Variable(tf.zeros([output_size]), name='b')
@tf.Module.with_name_scope
def __call__(self, x):
y = tf.matmul(x, self.w) + self.b
return tf.nn.relu(y)
class FakeModule(tf.Module):
"""Fake model using tf.Module for testing."""
def __init__(self, input_size, name=None):
super().__init__(name=name)
with self.name_scope:
self.dense = _Dense(input_size, 4, name='dense')
self.dense2 = _Dense(4, 4, name='dense_1')
@tf.Module.with_name_scope
def __call__(self, x):
return self.dense2(self.dense(x))
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Builder class for preparing tf.train.Example."""
# https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
from __future__ import annotations
from typing import Mapping, Sequence, Union
import numpy as np
import tensorflow as tf
BytesValueType = Union[bytes, Sequence[bytes], str, Sequence[str]]
_to_array = lambda v: [v] if not isinstance(v, (list, np.ndarray)) else v
_to_bytes = lambda v: v.encode() if isinstance(v, str) else v
_to_bytes_array = lambda v: list(map(_to_bytes, _to_array(v)))
class TfExampleBuilder(object):
"""Builder class for preparing tf.train.Example.
Read API doc at https://www.tensorflow.org/api_docs/python/tf/train/Example.
Example usage:
>>> example_builder = TfExampleBuilder()
>>> example = (
example_builder.add_bytes_feature('feature_a', 'foobarbaz')
.add_ints_feature('feature_b', [1, 2, 3])
.example)
"""
def __init__(self) -> None:
self._example = tf.train.Example()
@property
def example(self) -> tf.train.Example:
"""Returns a copy of the generated tf.train.Example proto."""
return self._example
@property
  def serialized_example(self) -> bytes:
    """Returns the serialized bytes of the generated tf.train.Example proto."""
return self._example.SerializeToString()
def set(self, example: tf.train.Example) -> TfExampleBuilder:
"""Sets the example."""
self._example = example
return self
def reset(self) -> TfExampleBuilder:
"""Resets the example to an empty proto."""
self._example = tf.train.Example()
return self
###### Basic APIs for primitive data types ######
def add_feature_dict(
self, feature_dict: Mapping[str, tf.train.Feature]) -> TfExampleBuilder:
"""Adds the predefined `feature_dict` to the example.
    Note: Please prefer using feature-type-specific methods.
Args:
feature_dict: A dictionary from tf.Example feature key to
tf.train.Feature.
Returns:
The builder object for subsequent method calls.
"""
for k, v in feature_dict.items():
self._example.features.feature[k].CopyFrom(v)
return self
def add_feature(self, key: str,
feature: tf.train.Feature) -> TfExampleBuilder:
"""Adds predefined `feature` with `key` to the example.
Args:
key: String key of the feature.
feature: The feature to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
self._example.features.feature[key].CopyFrom(feature)
return self
def add_bytes_feature(self, key: str,
value: BytesValueType) -> TfExampleBuilder:
"""Adds byte(s) or string(s) with `key` to the example.
Args:
key: String key of the feature.
value: The byte(s) or string(s) to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
return self.add_feature(
key,
tf.train.Feature(
bytes_list=tf.train.BytesList(value=_to_bytes_array(value))))
def add_ints_feature(self, key: str,
value: Union[int, Sequence[int]]) -> TfExampleBuilder:
"""Adds integer(s) with `key` to the example.
Args:
key: String key of the feature.
value: The integer(s) to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
return self.add_feature(
key,
tf.train.Feature(int64_list=tf.train.Int64List(value=_to_array(value))))
def add_floats_feature(
self, key: str, value: Union[float, Sequence[float]]) -> TfExampleBuilder:
"""Adds float(s) with `key` to the example.
Args:
key: String key of the feature.
value: The float(s) to be added to the example.
Returns:
The builder object for subsequent method calls.
"""
return self.add_feature(
key,
tf.train.Feature(float_list=tf.train.FloatList(value=_to_array(value))))
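# A minimal chained-usage sketch (not part of the original module); the
# feature keys below are illustrative only.
def _example_build_tf_example() -> tf.train.Example:
  builder = TfExampleBuilder()
  return (
      builder.add_bytes_feature('image/encoded', b'raw-image-bytes')
      .add_ints_feature('image/height', 480)
      .add_floats_feature('box/score', [0.9, 0.1])
      .example)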
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_example_builder.
See `test_add_image_matrix_feature_with_fake_image` for the typical structure of
a unit test.
"""
from absl.testing import parameterized
import tensorflow as tf
from official.core import tf_example_builder
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
def test_init_an_empty_example(self):
example_builder = tf_example_builder.TfExampleBuilder()
example = example_builder.example
self.assertProtoEquals('', example)
def test_init_an_empty_serialized_example(self):
example_builder = tf_example_builder.TfExampleBuilder()
example = example_builder.serialized_example
self.assertProtoEquals('', example)
def test_add_feature(self):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_feature(
'foo',
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b'Hello World!'])))
example = example_builder.example
    # Use proto text to show how the entire proto would look.
self.assertProtoEquals(
"""
features: {
feature: {
key: "foo"
value: {
bytes_list: {
value: "Hello World!"
}
}
}
}""", example)
def test_add_feature_dict(self):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_feature_dict({
'foo':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b'Hello World!'])),
'bar':
tf.train.Feature(
int64_list=tf.train.Int64List(value=[299, 792, 458]))
})
example = example_builder.example
    # Use proto text to show how the entire proto would look.
self.assertProtoEquals(
"""
features: {
feature: {
key: "foo"
value: {
bytes_list: {
value: "Hello World!"
}
}
}
feature: {
key: "bar"
value: {
int64_list: {
value: 299
value: 792
value: 458
}
}
}
}""", example)
@parameterized.named_parameters(
('single_bytes', b'Hello World!', b'Hello World!'),
('single_string', 'Hello World!', b'Hello World!'))
def test_add_single_byte_feature(self, value, expected_value):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_bytes_feature('foo', value)
example = example_builder.example
# Use constructor to easily work with test parameters.
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'foo':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[expected_value]))
})), example)
  @parameterized.named_parameters(
      ('multiple_bytes', [b'Hello World!', b'Good Morning!'],
       [b'Hello World!', b'Good Morning!']),
      ('multiple_string', ['Hello World!', 'Good Morning!'],
       [b'Hello World!', b'Good Morning!']))
def test_add_multiple_bytes_feature(self, values, expected_values):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_bytes_feature('foo', values)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'foo':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=expected_values))
})), example)
@parameterized.named_parameters(
('single_integer', 123, [123]),
('multiple_integers', [123, 456, 789], [123, 456, 789]))
def test_add_ints_feature(self, value, expected_value):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_ints_feature('bar', value)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'bar':
tf.train.Feature(
int64_list=tf.train.Int64List(value=expected_value))
})), example)
@parameterized.named_parameters(
('single_float', 3.14, [3.14]),
('multiple_floats', [3.14, 1.57, 6.28], [3.14, 1.57, 6.28]))
def test_add_floats_feature(self, value, expected_value):
example_builder = tf_example_builder.TfExampleBuilder()
example_builder.add_floats_feature('baz', value)
example = example_builder.example
self.assertProtoEquals(
tf.train.Example(
features=tf.train.Features(
feature={
'baz':
tf.train.Feature(
float_list=tf.train.FloatList(value=expected_value))
})), example)
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data classes for tf.Example proto feature keys.
Feature keys are grouped by feature types. Key names follow conventions in
go/tf-example.
"""
import dataclasses
import functools
from typing import Optional
# Disable the generated __init__ to use the one defined in the base class.
dataclass = functools.partial(dataclasses.dataclass(init=False))
@dataclass
class TfExampleFeatureKeyBase:
"""Base dataclass for defining tf.Example proto feature keys.
This class defines the logic of adding prefix to feature keys. Subclasses
will define feature keys for a specific feature type in data fields.
NOTE: Please follow subclass examples in this module to define feature keys
for a new feature type.
"""
def __init__(self, prefix: Optional[str] = None):
"""Instantiates the feature key class.
    Adds a string prefix to all fields of a feature key instance if `prefix`
    is neither None nor empty.
Example usage:
>>> test_key = EncodedImageFeatureKey()
>>> test_key.encoded
image/encoded
>>> test_key = EncodedImageFeatureKey('prefix')
>>> test_key.encoded
prefix/image/encoded
Args:
prefix: A prefix string that will be added before the feature key string
with a trailing slash '/'.
"""
if prefix:
for field in dataclasses.fields(self): # pytype: disable=wrong-arg-types # re-none
key_name = field.name
key_value = getattr(self, key_name)
setattr(self, key_name, f'{prefix}/{key_value}')
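# A hypothetical subclass sketch (not part of the original module) showing the
# pattern the base class describes; the key names are illustrative only.
@dataclass
class _ExampleImageFeatureKey(TfExampleFeatureKeyBase):
  encoded: str = 'image/encoded'
  height: str = 'image/height'
  width: str = 'image/width'
# With a prefix, _ExampleImageFeatureKey('prefix').encoded becomes
# 'prefix/image/encoded'.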
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf_example_feature_key."""
import dataclasses
import inspect
from absl.testing import absltest
from absl.testing import parameterized
from official.core import tf_example_feature_key
@tf_example_feature_key.dataclass
class TestFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
test: str = 'foo/bar'
class TfExampleFeatureKeyTest(parameterized.TestCase):
def test_add_prefix_success(self):
test_key = TestFeatureKey('prefix')
self.assertEqual(test_key.test, 'prefix/foo/bar')
@parameterized.parameters(None, '')
def test_add_prefix_skip_success(self, prefix):
test_key = TestFeatureKey(prefix)
self.assertEqual(test_key.test, 'foo/bar')
def test_all_feature_key_classes_are_valid(self):
for _, obj in inspect.getmembers(tf_example_feature_key):
if inspect.isclass(obj):
self.assertTrue(dataclasses.is_dataclass(obj))
self.assertTrue(
issubclass(obj, tf_example_feature_key.TfExampleFeatureKeyBase))
if __name__ == '__main__':
absltest.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
import tempfile
from typing import Any, List, Mapping, Optional, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import actions
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import train_utils
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
class OrbitExperimentRunner:
"""Runs experiment with Orbit training loop.
The default experiment runner for model garden experiments. User can
customize the experiment pipeline by subclassing this class and replacing
components or functions.
For example, an experiment runner with customized checkpoint manager:
```python
  class MyExpRunnerWithExporter(OrbitExperimentRunner):
    def _maybe_build_checkpoint_manager(self):
      # Replaces the default CheckpointManager with a customized one.
      return MyCheckpointManager(*args)

  # In user code, instead of the original
  # `OrbitExperimentRunner(..).run(mode)`, now the user can do:
  MyExpRunnerWithExporter(**needed_kwargs).run(mode)
```
Similar override can be done to other components.
"""
def __init__(
self,
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
):
"""Constructor.
Args:
distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval',
'train_and_eval' or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run a post-training evaluation once after
        training; its metrics logs are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within
the strategy.scope().
      controller_cls: The controller class to manage the train and eval
        process. Must be an orbit.Controller subclass.
summary_manager: Instance of the summary manager to override default
summary manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
"""
self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params
self._model_dir = model_dir
self._mode = mode
self._run_post_eval = run_post_eval
self._trainer = trainer or self._build_trainer(
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval)
assert self.trainer is not None
self._checkpoint_manager = self._maybe_build_checkpoint_manager()
self._summary_manager = summary_manager
self._eval_summary_manager = eval_summary_manager
self._controller = self._build_controller(
trainer=self.trainer if 'train' in mode else None,
evaluator=self.trainer,
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
controller_cls=controller_cls)
@property
def params(self) -> config_definitions.ExperimentConfig:
"""The whole experiment parameters object."""
return self._params
@property
def model_dir(self) -> str:
"""Path to the model folder, which stores checkpoints, params, log, etc."""
return self._model_dir
@property
def trainer(self) -> base_trainer.Trainer:
"""The underlying Orbit Trainer object."""
return self._trainer
@property
def checkpoint_manager(self) -> tf.train.CheckpointManager:
"""The CheckpointManager that stores the checkpoints in a train job."""
return self._checkpoint_manager
@property
def controller(self) -> orbit.Controller:
"""The Orbit controller object."""
return self._controller
def _build_trainer(self, task: base_task.Task, train: bool,
evaluate: bool) -> base_trainer.Trainer:
"""Create trainer."""
with self.strategy.scope():
trainer = train_utils.create_trainer(
self.params,
task,
train=train,
evaluate=evaluate,
checkpoint_exporter=self._build_best_checkpoint_exporter())
return trainer
def _build_best_checkpoint_exporter(self):
return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
def _maybe_build_checkpoint_manager(
self) -> Optional[tf.train.CheckpointManager]:
"""Maybe create a CheckpointManager."""
assert self.trainer is not None
if self.trainer.checkpoint:
if self.model_dir is None:
raise ValueError('model_dir must be specified, but got None')
if (not self.strategy) or self.strategy.extended.should_checkpoint:
ckpt_path = self.model_dir
max_to_keep = self.params.trainer.max_to_keep
else:
        # In multi-worker training we need every worker to save a checkpoint,
        # because variables can trigger synchronization on read and
        # synchronization needs all workers to participate. To avoid workers
        # overwriting each other's checkpoints, we save to a temporary
        # directory on non-chief workers.
ckpt_path = tempfile.mkdtemp()
max_to_keep = 1
checkpoint_manager = tf.train.CheckpointManager(
self.trainer.checkpoint,
directory=ckpt_path,
max_to_keep=max_to_keep,
step_counter=self.trainer.global_step,
checkpoint_interval=self.params.trainer.checkpoint_interval,
init_fn=self.trainer.initialize)
else:
checkpoint_manager = None
return checkpoint_manager
def _build_controller(self,
trainer,
evaluator,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller) -> orbit.Controller:
"""Builds a Orbit controler."""
train_actions = [] if not train_actions else train_actions
if trainer:
train_actions += actions.get_train_actions(
self.params,
trainer,
self.model_dir,
checkpoint_manager=self.checkpoint_manager)
eval_actions = [] if not eval_actions else eval_actions
if evaluator:
eval_actions += actions.get_eval_actions(self.params, evaluator,
self.model_dir)
if save_summary:
eval_summary_dir = os.path.join(
self.model_dir, self.params.trainer.validation_summary_subdir
)
else:
eval_summary_dir = None
controller = controller_cls(
strategy=self.strategy,
trainer=trainer,
evaluator=evaluator,
global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager,
summary_dir=os.path.join(self.model_dir, 'train')
if (save_summary)
else None,
eval_summary_dir=eval_summary_dir,
summary_interval=self.params.trainer.summary_interval
if (save_summary)
else None,
train_actions=train_actions,
eval_actions=eval_actions,
summary_manager=self._summary_manager
if hasattr(self, '_summary_manager')
else None,
eval_summary_manager=self._eval_summary_manager
if hasattr(self, '_eval_summary_manager')
else None,
)
return controller
def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Run experiments by mode.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
mode = self._mode
params = self.params
logging.info('Starts to execute mode: %s', mode)
with self.strategy.scope():
if mode == 'train' or mode == 'train_and_post_eval':
self.controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
self.controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
self.controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if self.trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
self.controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
num_params = train_utils.try_count_params(self.trainer.model)
if num_params is not None:
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(self.trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if self._run_post_eval or mode == 'train_and_post_eval':
with self.strategy.scope():
return self.trainer.model, self.controller.evaluate(
steps=params.trainer.validation_steps)
else:
return self.trainer.model, {}
def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
    distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run a post-training evaluation once after
      training; its metrics logs are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
    controller_cls: The controller class to manage the train and eval process.
      Must be an orbit.Controller subclass.
summary_manager: Instance of the summary manager to override default summary
manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
runner = OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval,
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
trainer=trainer,
controller_cls=controller_cls,
summary_manager=summary_manager,
eval_summary_manager=eval_summary_manager,
)
return runner.run()
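# A minimal invocation sketch (not part of the original module); `my_task` and
# `my_params` stand in for a configured Task and ExperimentConfig.
#
#   model, eval_logs = run_experiment(
#       distribution_strategy=tf.distribute.get_strategy(),
#       task=my_task,
#       mode='train_and_eval',
#       params=my_params,
#       model_dir='/tmp/my_model_dir',
#       run_post_eval=True)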
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train_ctl_lib."""
import json
import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
tfm_flags.define_flags()
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._test_config = {
'trainer': {
'checkpoint_interval': 10,
'steps_per_loop': 10,
'summary_interval': 10,
'train_steps': 10,
'validation_steps': 5,
'validation_interval': 10,
'continuous_eval_timeout': 1,
'validation_summary_subdir': 'validation',
'optimizer_config': {
'optimizer': {
'type': 'sgd',
},
'learning_rate': {
'type': 'constant'
}
}
},
}
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval)
if 'eval' in flag_mode:
self.assertTrue(
tf.io.gfile.exists(
os.path.join(model_dir,
params.trainer.validation_summary_subdir)))
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end_class(self, distribution_strategy, flag_mode,
run_post_eval):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
_, logs = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval).run()
if 'eval' in flag_mode:
self.assertTrue(
tf.io.gfile.exists(
os.path.join(model_dir,
params.trainer.validation_summary_subdir)))
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval).run()
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'train_and_eval'],
))
def test_recovery_nan_error(self, distribution_strategy, flag_mode):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
# task = task_factory.get_task(params.task, logging_dir=model_dir)
task = mock_task.MockTask(params.task, logging_dir=model_dir)
# Set the loss to NaN to trigger RunTimeError.
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([np.nan], tf.float32) + aux_losses
task.build_losses = build_losses
with self.assertRaises(RuntimeError):
train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir).run()
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train'],
))
def test_recovery(self, distribution_strategy, flag_mode):
loss_threshold = 1.0
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
params.trainer.loss_upper_bound = loss_threshold
params.trainer.recovery_max_trials = 1
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
# Saves a checkpoint for reference.
model = task.build_model()
checkpoint = tf.train.Checkpoint(model=model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.get_temp_dir(), max_to_keep=2)
checkpoint_manager.save()
before_weights = model.get_weights()
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([loss_threshold], tf.float32) + aux_losses
task.build_losses = build_losses
model, _ = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir).run()
after_weights = model.get_weights()
for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right)
def test_parse_configuration(self):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode='train',
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
with self.assertRaises(ValueError):
params.override({'task': {'init_checkpoint': 'Foo'}})
params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
params.override({'task': {'init_checkpoint': 'Bar'}})
self.assertEqual(params.task.init_checkpoint, 'Bar')
if __name__ == '__main__':
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training utils."""
import dataclasses
import inspect
import json
import os
import pprint
from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging
import gin
import orbit
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import exp_factory
from official.modeling import hyperparams
BEST_CHECKPOINT_NAME = 'best_ckpt'
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Get leaf from a dictionary with arbitrary depth with a list of keys.
Args:
d: The dictionary to extract value from.
keys: The list of keys to extract values recursively.
Returns:
The value of the leaf.
Raises:
KeyError: If the value of keys extracted is a dictionary.
"""
leaf = d
for k in keys:
if not isinstance(leaf, dict) or k not in leaf:
      raise KeyError(
          'Path does not exist while traversing the dictionary with keys'
          ': %s.' % keys)
leaf = leaf[k]
if isinstance(leaf, dict):
raise KeyError('The value extracted with keys: %s is not a leaf of the '
'dictionary: %s.' % (keys, d))
return leaf
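# A worked example (not part of the original module):
# get_leaf_nested_dict({'a': {'b': 1}}, ['a', 'b']) returns 1, while
# get_leaf_nested_dict({'a': {'b': 1}}, ['a']) raises KeyError because the
# extracted value is itself a dictionary.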
def cast_leaf_nested_dict(d: Dict[str, Any],
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
"""Cast the leaves of a dictionary with arbitrary depth in place.
Args:
d: The dictionary to extract value from.
cast_fn: The casting function.
Returns:
    A dictionary with the same structure as d.
"""
for key, value in d.items():
if isinstance(value, dict):
d[key] = cast_leaf_nested_dict(value, cast_fn)
else:
d[key] = cast_fn(value)
return d
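# For example (illustrative): cast_leaf_nested_dict({'a': {'b': 1}}, float)
# returns {'a': {'b': 1.0}}, mutating the input dictionary in place as
# documented above.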
def _filter_leaf_nested_dict(
d: Dict[str, Any], predicate: Callable[[Any], bool]
) -> Dict[str, Any]:
"""Filters the leaves of a dictionary with arbitrary depth in place.
Args:
d: The dictionary to extract value from.
    predicate: A function that will be called on every leaf item. When the
      function returns True the leaf will be kept. Otherwise the leaf will be
      dropped.
Returns:
    A new dictionary with the filtered result.
"""
result = {}
for key, value in d.items():
if isinstance(value, dict):
result[key] = _filter_leaf_nested_dict(value, predicate)
elif predicate(value):
result[key] = value
return result
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
metric_comp)
logging.info(
'Created the best checkpoint exporter. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
export_subdir, metric_name)
else:
best_ckpt_exporter = None
return best_ckpt_exporter
class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint.
Orbit will support an API for checkpoint exporter. This class will be used
together with orbit once this functionality is ready.
"""
def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
"""Initialization.
Args:
export_dir: The directory that will contain exported checkpoints.
      metric_name: Indicates which metric to look at, when determining which
        result is better. If the eval_logs passed to maybe_export_checkpoint
        is a nested dictionary, use `|` as a separator between levels.
metric_comp: Indicates how to compare results. Either `lower` or `higher`.
"""
self._export_dir = export_dir
self._metric_name = metric_name.split('|')
self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'):
raise ValueError('best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
self._checkpoint_manager = None
def _get_checkpoint_manager(self, checkpoint):
"""Gets an existing checkpoint manager or creates a new one."""
if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
!= checkpoint):
logging.info('Creates a new checkpoint manager.')
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=self._export_dir,
max_to_keep=1,
checkpoint_name=BEST_CHECKPOINT_NAME)
return self._checkpoint_manager
def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
if write_logs:
self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
return None
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
return json.loads(reader.read())
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
old_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(old_logs, self._metric_name)))
new_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(new_logs, self._metric_name)))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value)
if self._metric_comp == 'higher':
if new_value > old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is higher.')
return True
else: # self._metric_comp == 'lower':
if new_value < old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is lower.')
return True
return False
def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file."""
    # eval_logs may contain non-scalar tensors, such as image data when
    # `allow_image_summary` is True. Here we keep only scalar and rank-1
    # tensors.
eval_logs_ext = _filter_leaf_nested_dict(
eval_logs, lambda x: tf.rank(x) <= 1
)
eval_logs_ext['best_ckpt_global_step'] = global_step
eval_logs_ext = cast_leaf_nested_dict(
eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
# Saving json file is very fast.
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
@property
def best_ckpt_logs(self):
return self._best_ckpt_logs
@property
def best_ckpt_logs_path(self):
return os.path.join(self._export_dir, 'info.json')
@property
def best_ckpt_path(self):
"""Returns the best ckpt path or None if there is no ckpt yet."""
return tf.train.latest_checkpoint(self._export_dir)
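# A minimal usage sketch (not part of the original module); the metric name,
# value, and step below are illustrative only.
def _example_best_ckpt_export(export_dir: str,
                              checkpoint: tf.train.Checkpoint) -> bool:
  exporter = BestCheckpointExporter(
      export_dir, metric_name='accuracy', metric_comp='higher')
  # Exports the checkpoint only if 'accuracy' beats the stored best value.
  return exporter.maybe_export_checkpoint(
      checkpoint, eval_logs={'accuracy': 0.9}, global_step=100)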
def create_optimizer(task: base_task.Task,
params: config_definitions.ExperimentConfig
) -> tf.keras.optimizers.Optimizer:
"""A create optimizer util to be backward compatability with new args."""
if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
dp_config = None
if hasattr(params.task, 'differential_privacy_config'):
dp_config = params.task.differential_privacy_config
optimizer = task.create_optimizer(
params.trainer.optimizer_config, params.runtime,
dp_config=dp_config)
else:
if hasattr(params.task, 'differential_privacy_config'
) and params.task.differential_privacy_config is not None:
raise ValueError('Differential privacy config is specified but '
'task.create_optimizer api does not accept it.')
optimizer = task.create_optimizer(
params.trainer.optimizer_config,
params.runtime)
return optimizer
@gin.configurable
def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task,
train: bool,
evaluate: bool,
checkpoint_exporter: Optional[BestCheckpointExporter] = None,
trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
"""Create trainer."""
logging.info('Running default trainer.')
model = task.build_model()
optimizer = create_optimizer(task, params)
return trainer_cls(
params,
task,
model=model,
optimizer=optimizer,
train=train,
evaluate=evaluate,
checkpoint_exporter=checkpoint_exporter)
@dataclasses.dataclass
class ParseConfigOptions:
"""Use this dataclass instead of FLAGS to customize parse_configuration()."""
experiment: str
config_file: List[str]
tpu: str = ''
tf_data_service: str = ''
params_override: str = ''
def __contains__(self, name):
return name in dataclasses.asdict(self)
class ExperimentParser:
"""Constructs the Experiment config from Flags or equivalent object.
In most cases, users only need to call the `parse()` function:
```
builder = ExperimentParser(FLAGS)
params = builder.parse()
```
Advanced users can modify the flow by calling the parse_*() functions
separately.
"""
def __init__(self, flags_obj):
self._flags_obj = flags_obj
def parse(self):
"""Overrall process of constructing Experiment config."""
params = self.base_experiment()
params = self.parse_config_file(params)
params = self.parse_runtime(params)
params = self.parse_data_service(params)
params = self.parse_params_override(params)
return params
def base_experiment(self):
"""Get the base experiment config from --experiment field."""
if self._flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
return exp_factory.get_exp_config(self._flags_obj.experiment)
def parse_config_file(self, params):
"""Override the configs of params from the config_file."""
for config_file in self._flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
return params
def parse_runtime(self, params):
"""Override the runtime configs of params from flags."""
# Override the TPU address.
params.override({
'runtime': {
'tpu': self._flags_obj.tpu,
},
})
return params
def parse_data_service(self, params):
"""Override the data service configs of params from flags."""
if ('tf_data_service' in self._flags_obj and
self._flags_obj.tf_data_service and
isinstance(params.task, config_definitions.TaskConfig)):
params.override({
'task': {
'train_data': {
'tf_data_service_address': self._flags_obj.tf_data_service,
},
'validation_data': {
'tf_data_service_address': self._flags_obj.tf_data_service,
}
}
})
return params
def parse_params_override(self, params):
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if self._flags_obj.params_override:
params = hyperparams.override_params_dict(
params, self._flags_obj.params_override, is_strict=True)
return params
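Putting the parser together, and showing the override order (base experiment, then `--config_file` entries in order, then `--params_override`). The experiment name and YAML path below are hypothetical; the import path matches the tests later in this file:
```python
from official.core import train_utils

options = train_utils.ParseConfigOptions(
    experiment='my_experiment',  # must be registered with exp_factory
    config_file=['/path/to/template.yaml'],  # applied first, in order
    tpu='',
    tf_data_service='',
    # Applied last, so it wins over the config files above.
    params_override='trainer.train_steps=10')
params = train_utils.ExperimentParser(options).parse()
```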
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
params = ExperimentParser(flags_obj).parse()
params.validate()
if lock_return:
params.lock()
if print_return:
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict()))
return params
def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str):
"""Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def save_gin_config(filename_suffix: str, model_dir: str):
"""Serializes and saves the experiment config."""
gin_save_path = os.path.join(
model_dir, 'operative_config.{}.gin'.format(filename_suffix))
logging.info('Saving gin configurations to %s', gin_save_path)
tf.io.gfile.makedirs(model_dir)
with tf.io.gfile.GFile(gin_save_path, 'w') as f:
f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64)
ckpt = tf.train.Checkpoint(global_step=global_step)
try:
ckpt.restore(ckpt_file_path).expect_partial()
global_step_maybe_restored = global_step.numpy()
except tf.errors.InvalidArgumentError:
global_step_maybe_restored = -1
if global_step_maybe_restored == -1:
raise ValueError('global_step not found in checkpoint {}. '
'If you want to run finetune eval jobs, you need to '
'make sure that your pretrain model writes '
'global_step in its checkpoints.'.format(ckpt_file_path))
global_step_restored = global_step.numpy()
logging.info('get global_step %d from checkpoint %s', global_step_restored,
ckpt_file_path)
return global_step_restored
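For the restore above to succeed, the pretrain job must have attached `global_step` to its checkpoint. A minimal round-trip sketch, with an illustrative path, assuming the helper above is in scope:
```python
import os
import tensorflow as tf

os.makedirs('/tmp/pretrain', exist_ok=True)
global_step = tf.Variable(123, dtype=tf.int64)
ckpt = tf.train.Checkpoint(global_step=global_step)
ckpt_path = ckpt.save('/tmp/pretrain/ckpt')  # e.g. '/tmp/pretrain/ckpt-1'

# The helper above can now recover the step from the checkpoint.
assert read_global_step_from_checkpoint(ckpt_path) == 123
```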
def write_json_summary(log_dir, global_step, eval_metrics):
"""Dump evaluation metrics to json file."""
serializable_dict = {}
for name, value in eval_metrics.items():
if hasattr(value, 'numpy'):
serializable_dict[name] = str(value.numpy())
else:
serializable_dict[name] = str(value)
output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
logging.info('Evaluation results at pretrain step %d: %s', global_step,
serializable_dict)
with tf.io.gfile.GFile(output_json, 'w') as writer:
writer.write(json.dumps(serializable_dict, indent=4) + '\n')
def write_summary(summary_writer, global_step, eval_metrics):
"""Write evaluation metrics to TF summary."""
numeric_dict = {}
for name, value in eval_metrics.items():
numeric_dict[name] = float(orbit.utils.get_value(value))
with summary_writer.as_default():
for name, value in numeric_dict.items():
tf.summary.scalar(name, value, step=global_step)
summary_writer.flush()
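A usage sketch for the two writers above, assuming they are in scope; the log directory is hypothetical, and `tf.summary.create_file_writer` supplies the summary writer:
```python
import os
import tensorflow as tf

log_dir = '/tmp/eval_logs'  # hypothetical
os.makedirs(log_dir, exist_ok=True)
eval_metrics = {'accuracy': tf.constant(0.91), 'loss': 0.37}

write_json_summary(log_dir, 1000, eval_metrics)  # -> metrics-1000.json
write_summary(tf.summary.create_file_writer(log_dir), 1000, eval_metrics)
```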
def remove_ckpts(model_dir):
"""Remove model checkpoints, so we can restart."""
ckpts = os.path.join(model_dir, 'ckpt-*')
logging.info('removing checkpoint files %s', ckpts)
for file_to_remove in tf.io.gfile.glob(ckpts):
tf.io.gfile.rmtree(file_to_remove)
file_to_remove = os.path.join(model_dir, 'checkpoint')
if tf.io.gfile.exists(file_to_remove):
tf.io.gfile.remove(file_to_remove)
def write_model_params(model: Union[tf.Module, tf.keras.Model],
output_path: str) -> None:
"""Writes the model parameters and shapes to a file.
Args:
model: A model instance.
output_path: Output file path.
"""
with tf.io.gfile.GFile(output_path, 'w') as f:
total_params = 0
for var in model.variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
f.write(f'{var.name} {shape.numpy().tolist()}\n')
f.write(f'\nTotal params: {total_params}\n')
def try_count_params(
model: Union[tf.Module, tf.keras.Model],
trainable_only: bool = False):
"""Count the number of parameters if model is possible.
Args:
model: Try to count the number of params in this model.
trainable_only: Whether to calculate trainable params only. This flag is
not used when the model has `count_params` attribute.
Returns:
The number of parameters or None.
"""
if hasattr(model, 'count_params'):
try:
return model.count_params()
except ValueError:
logging.info('Number of trainable params unknown, because the build() '
'methods in keras layers were not called. This is probably '
'because the model was not fed any input, e.g., the max '
'train step was already reached before this run.')
return None
else:
total_params = 0
variables = model.trainable_variables if trainable_only else model.variables
for var in variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
return total_params
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
inputs_kwargs: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications, used to get the corresponding concrete function.
output_path: A file path to write the profiling results to.
Returns:
The model's FLOPs.
"""
if hasattr(model, 'inputs'):
try:
# Get input shape and set batch size to 1.
if model.inputs:
inputs = [
tf.TensorSpec([1] + input.shape[1:], input.dtype)
for input in model.inputs
]
concrete_func = tf.function(model).get_concrete_function(inputs)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else:
concrete_func = tf.function(model.call).get_concrete_function(
**inputs_kwargs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
# Calculate FLOPs.
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
if output_path is not None:
opts['output'] = f'file:outfile={output_path}'
else:
opts['output'] = 'none'
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, options=opts)
return flops.total_float_ops
except Exception as e: # pylint: disable=broad-except
logging.info(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not fed any input, e.g., the max train step was '
'already reached before this run.', e)
return None
return None
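A quick sketch exercising both counters on a tiny functional model, assuming the two helpers above are in scope; the FLOPs count may come back `None` if profiling fails:
```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(4)(inputs)
model = tf.keras.Model(inputs, outputs)

print('params:', try_count_params(model))  # 8 * 4 weights + 4 biases = 36
print('flops:', try_count_flops(model))    # total float ops, or None
```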
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.train_utils."""
import json
import os
import pprint
import numpy as np
import tensorflow as tf
from official.core import exp_factory
from official.core import test_utils
from official.core import train_utils
from official.modeling import hyperparams
@exp_factory.register_config_factory('foo')
def foo():
"""Multitask experiment for test."""
experiment_config = hyperparams.Config(
default_params={
'runtime': {
'tpu': 'fake',
},
'task': {
'model': {
'model_id': 'bar',
},
},
'trainer': {
'train_steps': -1,
'validation_steps': -1,
},
})
return experiment_config
class TrainUtilsTest(tf.test.TestCase):
def test_get_leaf_nested_dict(self):
d = {'a': {'i': {'x': 5}}}
self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)
def test_get_leaf_nested_dict_not_leaf(self):
with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
d = {'a': {'i': {'x': 5}}}
train_utils.get_leaf_nested_dict(d, ['a', 'i'])
def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
d = {'a': {'i': {'x': 5}}}
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])
def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
d = {'a': {'i': {'x': 5}}}
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
d = {'a': {'i': 5}}
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
def test_cast_leaf_nested_dict(self):
d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
d = train_utils.cast_leaf_nested_dict(d, int)
self.assertEqual(d['a']['i']['x'], 123)
self.assertEqual(d['b'], 456)
def test_write_model_params_keras_model(self):
inputs = np.zeros([2, 3])
model = test_utils.FakeKerasModel()
model(inputs) # Must do forward pass to build the model.
filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
train_utils.write_model_params(model, filepath)
actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
expected = [
'fake_keras_model/dense/kernel:0 [3, 4]',
'fake_keras_model/dense/bias:0 [4]',
'fake_keras_model/dense_1/kernel:0 [4, 4]',
'fake_keras_model/dense_1/bias:0 [4]',
'',
'Total params: 36',
]
self.assertEqual(actual, expected)
def test_write_model_params_module(self):
inputs = np.zeros([2, 3], dtype=np.float32)
model = test_utils.FakeModule(3, name='fake_module')
model(inputs) # Must do forward pass to build the model.
filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
train_utils.write_model_params(model, filepath)
actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
expected = [
'fake_module/dense/b:0 [4]',
'fake_module/dense/w:0 [3, 4]',
'fake_module/dense_1/b:0 [4]',
'fake_module/dense_1/w:0 [4, 4]',
'',
'Total params: 36',
]
self.assertEqual(actual, expected)
def test_construct_experiment_from_flags(self):
options = train_utils.ParseConfigOptions(
experiment='foo',
config_file=[],
tpu='bar',
tf_data_service='',
params_override='task.model.model_id=new,'
'trainer.train_steps=10,'
'trainer.validation_steps=11')
builder = train_utils.ExperimentParser(options)
params_from_obj = builder.parse()
params_from_func = train_utils.parse_configuration(options)
pp = pprint.PrettyPrinter()
self.assertEqual(
pp.pformat(params_from_obj.as_dict()),
pp.pformat(params_from_func.as_dict()))
self.assertEqual(params_from_obj.runtime.tpu, 'bar')
self.assertEqual(params_from_obj.task.model.model_id, 'new')
self.assertEqual(params_from_obj.trainer.train_steps, 10)
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
class BestCheckpointExporterTest(tf.test.TestCase):
def test_maybe_export(self):
model_dir = self.create_tempdir().full_path
best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(
model_dir, metric_name, 'higher')
v = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
with self.subTest(name='Successful first save.'):
self.assertEqual(ret, True)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 1.0)
v = tf.Variable(3.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
with self.subTest(name='Successful better metric save.'):
self.assertEqual(ret, True)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 3.0)
v = tf.Variable(5.0)
checkpoint = tf.train.Checkpoint(v=v)
ret = exporter.maybe_export_checkpoint(
checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
with self.subTest(name='Worse metric, no save.'):
self.assertEqual(ret, False)
v_2 = tf.Variable(2.0)
checkpoint_2 = tf.train.Checkpoint(v=v_2)
checkpoint_2.restore(best_ckpt_path)
self.assertEqual(v_2.numpy(), 3.0)
def test_export_best_eval_metric(self):
model_dir = self.create_tempdir().full_path
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
'higher')
exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
'rb') as reader:
metric = json.loads(reader.read())
self.assertAllEqual(
metric,
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
def test_export_best_eval_metric_skips_non_scalar_values(self):
model_dir = self.create_tempdir().full_path
metric_name = 'test_metric|metric_1'
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
'higher')
image = tf.zeros(shape=[16, 8, 1])
eval_logs = {'test_metric': {'metric_1': 5.0, 'image': image}}
exporter.export_best_eval_metric(eval_logs, 100)
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
'rb') as reader:
metric = json.loads(reader.read())
self.assertAllEqual(
metric,
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
if __name__ == '__main__':
tf.test.main()
Models in this `legacy` directory are mainly used for benchmarking.
Please note that the models in this `legacy` directory are not supported to the
same extent as the models in official/nlp and official/vision.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ALBERT (A Lite BERT for Self-supervised Learning of Language Representations)
**WARNING**: This directory is deprecated.
See `nlp/docs/MODEL_GARDEN.md` for the new ALBERT implementation.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The ALBERT configurations."""
import six
from official.legacy.bert import configs
class AlbertConfig(configs.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
"""Constructs AlbertConfig.
Args:
num_hidden_groups: Number of groups for the hidden layers; parameters in
the same group are shared. Note that this value, as well as the following
'inner_group_num', has to be 1 for now, because all released ALBERT
models set them to 1. We may support arbitrary valid values in the future.
inner_group_num: Number of inner repetitions of attention and FFN.
**kwargs: The remaining arguments are the same as above 'BertConfig'.
"""
super(AlbertConfig, self).__init__(**kwargs)
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
# in the released ALBERT. Support other values in AlbertEncoder if needed.
if inner_group_num != 1 or num_hidden_groups != 1:
raise ValueError("We only support 'inner_group_num' and "
"'num_hidden_groups' as 1.")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
config = AlbertConfig(vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
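A usage sketch for `from_dict`; the config file path is hypothetical, and the JSON layout (including the `hidden_size` key) is assumed to follow the released ALBERT configs:
```python
import json

with open('/path/to/albert_config.json') as f:
  albert_config = AlbertConfig.from_dict(json.load(f))
# Every key in the JSON becomes an attribute on the config, e.g.:
print(albert_config.hidden_size)
```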
# BERT (Bidirectional Encoder Representations from Transformers)
**WARNING**: We are on the way to deprecating most of the code in this directory.
Please see
[this link](../g3doc/tutorials/bert_new.md)
for the new tutorial and use the new code in `nlp/modeling`. This README is
still correct for this legacy implementation.
The academic paper which describes BERT in detail and provides full results on a
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
This repository contains a TensorFlow 2.x implementation of BERT.
## Contents
* [Contents](#contents)
* [Pre-trained Models](#pre-trained-models)
* [Restoring from Checkpoints](#restoring-from-checkpoints)
* [Set Up](#set-up)
* [Process Datasets](#process-datasets)
* [Fine-tuning with BERT](#fine-tuning-with-bert)
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
* [SQuAD 1.1](#squad-1.1)
## Pre-trained Models
We release both checkpoints and tf.hub modules as pretrained models for
fine-tuning. They are TF 2.x compatible and converted from the checkpoints
released in the TF 1.x official BERT repository
[google-research/bert](https://github.com/google-research/bert)
in order to stay consistent with the BERT paper.
### Access to Pretrained Checkpoints
Pretrained checkpoints can be found in the following links:
**Note: We have switched the BERT implementation
to use Keras functional-style networks in [nlp/modeling](../modeling).
The new checkpoints are:**
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
We recommend hosting checkpoints in Google Cloud Storage buckets when using
Cloud GPUs/TPUs.
### Restoring from Checkpoints
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
weights from provided pre-trained checkpoints, you can use the following code:
```python
init_checkpoint = 'the pretrained model checkpoint path.'
model = tf.keras.Model()  # BERT pre-trained model used as a feature extractor.
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(init_checkpoint)
```
Checkpoints featuring native serialized Keras models
(i.e. model.load()/load_weights()) will be available soon.
### Access to Pretrained Hub Modules
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
following links:
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
24-layer, 1024-hidden, 16-heads, 340M parameters
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
110M parameters
## Set Up
```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```
Install `tf-nightly` to get latest updates:
```shell
pip install tf-nightly-gpu
```
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
```shell
ctpu up --name=<instance name> --tf-version="nightly"
```
Second, you need to install TF 2 `tf-nightly` on your VM:
```shell
pip install tf-nightly
```
## Process Datasets
### Pre-training
Generating pre-training data is unchanged. Please use the script
[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py),
which is essentially branched from the [BERT research repo](https://github.com/google-research/bert)
and adapted to TF2 symbols and Python 3 compatibility, to get processed
pre-training data.
Running the pre-training script requires an input and output directory, as well as a vocab file. Note that `max_seq_length` will need to match the sequence length parameter you specify when you run pre-training.
Example shell script to call `create_pretraining_data.py`:
```
export WORKING_DIR='local disk or cloud location'
export BERT_DIR='local disk or cloud location'
python models/official/nlp/data/create_pretraining_data.py \
--input_file=$WORKING_DIR/input/input.txt \
--output_file=$WORKING_DIR/output/tf_examples.tfrecord \
--vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
--do_lower_case=True \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
```
### Fine-tuning
To prepare the fine-tuning data for final model training, use the
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
The resulting datasets in `tf_record` format and the training metadata should
later be passed to the training or evaluation scripts. The task-specific
arguments are described in the following sections:
* GLUE
Users can download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
Also, users can download a [pretrained checkpoint](#access-to-pretrained-checkpoints) and place it in some directory `$BERT_DIR` instead of using the checkpoints on Google Cloud Storage.
```shell
export GLUE_DIR=~/glue
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TASK_NAME=MNLI
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
--fine_tuning_task_type=classification --max_seq_length=128 \
--classification_task_name=${TASK_NAME}
```
* SQUAD
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
detailed information about the SQuAD datasets and evaluation.
The necessary files can be found here:
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
```shell
export SQUAD_DIR=~/squad
export SQUAD_VERSION=v1.1
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
--fine_tuning_task_type=squad --max_seq_length=384
```
Note: To create fine-tuning data with SQUAD 2.0, you need to add flag `--version_2_with_negative=True`.
## Fine-tuning with BERT
### Cloud GPUs and TPUs
* Cloud Storage
The unzipped pre-trained model files can also be found in the Google Cloud
Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
```
Currently, users are able to access `tf-nightly` TPUs, and the following TPU
script should run with `tf-nightly`.
* GPU -> TPU
Just add the following flags to `run_classifier.py` or `run_squad.py`:
```shell
--distribution_strategy=tpu
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
### Sentence and Sentence-pair Classification Tasks
This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
Corpus (MRPC), which only contains 3,600 examples and can fine-tune in a
few minutes on most GPUs.
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--eval_batch_size=4 \
--steps_per_loop=1 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Alternatively, instead of specifying `init_checkpoint`, you can specify
`hub_module_url` to employ a pre-trained BERT hub module, e.g.,
` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
After training a model, to get predictions from the classifier you can set
`--mode=predict` and pass the test set tfrecords to `--eval_data_path`.
The output will be written to a file called test_results.tsv in the output
folder. Each line contains the output for one sample; the columns are the
class probabilities.
```shell
python run_classifier.py \
--mode='predict' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--eval_batch_size=4 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
To use TPU, you only need to switch the distribution strategy type to `tpu` with TPU
information and use remote storage for model checkpoints.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export GLUE_DIR=gs://some_bucket/datasets
export TASK=MRPC
python run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
Note that we specify `steps_per_loop=1000` for TPU, because running a loop of
training steps inside a `tf.function` can significantly increase TPU
utilization; callbacks will not be called inside the loop.
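A hedged sketch of the host-side pattern behind `steps_per_loop`; `strategy`, `train_step`, and the dataset below are illustrative stand-ins, not the trainer's actual code:
```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # a TPUStrategy in practice


@tf.function
def train_step(features):
  # Hypothetical per-step work; the real trainer updates the model here.
  return tf.reduce_sum(features)


@tf.function
def train_loop(iterator, num_steps):
  # All `num_steps` steps run inside one compiled function, so the TPU is
  # not blocked on the host between steps; callbacks only observe loop
  # boundaries.
  for _ in tf.range(num_steps):
    strategy.run(train_step, args=(next(iterator),))


dataset = tf.data.Dataset.from_tensors(tf.ones([2, 2])).repeat()
train_loop(iter(dataset), tf.constant(1000))  # steps_per_loop=1000
```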
### SQuAD 1.1
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
benchmark dataset. See more on [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
workflow.
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
(uncased_L-12_H-768_A-12).
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export SQUAD_DIR=gs://some_bucket/datasets
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=4 \
--predict_batch_size=4 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=mirrored
```
Similarly, you can replace the `init_checkpoint` flag with `hub_module_url` to
specify a hub module path.
`run_squad.py` writes the prediction for `--predict_file` by default. If you set
`--mode=predict` and provide the SQuAD test data, the script will generate
the prediction json file.
To use TPU, you need to switch the distribution strategy type to `tpu` with TPU
information.
```shell
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
export TPU_IP_ADDRESS='???'
export MODEL_DIR=gs://some_bucket/my_output_dir
export SQUAD_DIR=gs://some_bucket/datasets
export SQUAD_VERSION=v1.1
python run_squad.py \
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
--vocab_file=${BERT_DIR}/vocab.txt \
--bert_config_file=${BERT_DIR}/bert_config.json \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--learning_rate=8e-5 \
--num_train_epochs=2 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```
The dev set predictions will be saved into a file called predictions.json in the
model_dir:
```shell
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
```
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# BERT FineTuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
## Set up Cloud Storage and Compute Engine VM
1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
2. Create a variable for the project's id:
```
export PROJECT_ID=your-project_id
```
3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
```
gcloud config set project ${PROJECT_ID}
```
4. Create a Cloud Storage bucket using the following command:
```
gsutil mb -p ${PROJECT_ID} -c standard -l europe-west4 -b on gs://your-bucket-name
```
This Cloud Storage bucket stores the data you use to train your model and the training results.
5. Launch a Compute Engine VM and Cloud TPU using the `ctpu up` command.
```
ctpu up --tpu-size=v3-8 \
--machine-type=n1-standard-8 \
--zone=europe-west4-a \
--tf-version=2.1 [optional flags: --project, --name]
```
6. The configuration you specified appears. Enter `y` to approve or `n` to cancel.
7. When the `ctpu up` command has finished executing, verify that your shell prompt has changed from `username@project` to `username@tpuname`. This change shows that you are now logged into your Compute Engine VM.
```
gcloud compute ssh vm-name --zone=europe-west4-a
(vm)$ export TPU_NAME=vm-name
```
As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
## Prepare the Dataset
1. From your Compute Engine virtual machine (VM), install `requirements.txt`.
```
(vm)$ cd /usr/share/models
(vm)$ sudo pip3 install -r official/requirements.txt
```
2. Optional: download `download_glue_data.py`.
This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at `gs://cloud-tpu-checkpoints/bert/classification`.
## Define parameter values
Next, define several parameter values that are required when you train and evaluate your model:
```
(vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
(vm)$ export STORAGE_BUCKET=gs://your-bucket-name
(vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
(vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
(vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
(vm)$ export TASK=mnli
```
## Train the model
From your Compute Engine VM, run the following command.
```
(vm)$ python3 official/nlp/bert/run_classifier.py \
--mode='train_and_eval' \
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=${TPU_NAME}
```
## Verify your results
The training takes approximately 1 hour on a v3-8 TPU. When the script completes, you should see results similar to the following:
```
Training Summary:
{'train_loss': 0.28142181038856506,
'last_train_metrics': 0.9467429518699646,
'eval_metrics': 0.8599063158035278,
'total_training_steps': 36813}
```
## Clean up
To avoid incurring charges to your GCP account for the resources used in this topic:
1. Disconnect from the Compute Engine VM:
```
(vm)$ exit
```
2. In your Cloud Shell, run `ctpu delete` with the `--zone` flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
```
$ ctpu delete --zone=your-zone
```
3. Run `ctpu status`, specifying your zone, to make sure you have no instances allocated and avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
```
$ ctpu status --zone=your-zone
```
4. Run `gsutil` as shown, replacing `your-bucket` with the name of the Cloud Storage bucket you created for this tutorial:
```
$ gsutil rm -r gs://your-bucket
```