ModelZoo / ResNet50_tensorflow

Commit b261ebb4, authored Oct 11, 2021 by A. Unique TensorFlower

Merge pull request #10286 from PurdueDualityLab:task_pr

PiperOrigin-RevId: 402338060

Parents: ca431476, 379d64c5

Showing 10 changed files with 1031 additions and 53 deletions
official/vision/beta/projects/yolo/ops/preprocessing_ops.py (+7, -7)
official/vision/beta/projects/yolo/optimization/__init__.py (+22, -0)
official/vision/beta/projects/yolo/optimization/configs/__init__.py (+14, -0)
official/vision/beta/projects/yolo/optimization/configs/optimization_config.py (+56, -0)
official/vision/beta/projects/yolo/optimization/configs/optimizer_config.py (+63, -0)
official/vision/beta/projects/yolo/optimization/optimizer_factory.py (+99, -0)
official/vision/beta/projects/yolo/optimization/sgd_torch.py (+312, -0)
official/vision/beta/projects/yolo/tasks/task_utils.py (+52, -0)
official/vision/beta/projects/yolo/tasks/yolo.py (+404, -0)
official/vision/beta/projects/yolo/train.py (+2, -46)
official/vision/beta/projects/yolo/ops/preprocessing_ops.py

@@ -170,14 +170,14 @@ def get_image_shape(image):

 def _augment_hsv_darknet(image, rh, rs, rv, seed=None):
   """Randomize the hue, saturation, and brightness via the darknet method."""
   if rh > 0.0:
-    delta = random_uniform_strong(-rh, rh, seed=seed)
-    image = tf.image.adjust_hue(image, delta)
+    deltah = random_uniform_strong(-rh, rh, seed=seed)
+    image = tf.image.adjust_hue(image, deltah)
   if rs > 0.0:
-    delta = random_scale(rs, seed=seed)
-    image = tf.image.adjust_saturation(image, delta)
+    deltas = random_scale(rs, seed=seed)
+    image = tf.image.adjust_saturation(image, deltas)
   if rv > 0.0:
-    delta = random_scale(rv, seed=seed)
-    image *= delta
+    deltav = random_scale(rv, seed=seed)
+    image *= tf.cast(deltav, image.dtype)

   # clip the values of the image between 0.0 and 1.0
   image = tf.clip_by_value(image, 0.0, 1.0)

@@ -719,7 +719,7 @@ def affine_warp_boxes(affine, boxes, output_size, box_history):
   return tf.stack([y_min, x_min, y_max, x_max], axis=-1)

 def _aug_boxes(affine_matrix, box):
-  """Apply an affine transformation matrix M to the boxes augmente boxes."""
+  """Apply an affine transformation matrix M to the boxes augment boxes."""
   corners = _get_corners(box)
   corners = tf.reshape(corners, [-1, 4, 2])
   z = tf.expand_dims(tf.ones_like(corners[..., 1]), axis=-1)
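The only functional change in the HSV hunk above is the last one: the brightness scale is now cast to the image dtype before the in-place multiply (the hue and saturation renames just give each delta a distinct name). A minimal sketch of why the cast matters, assuming a mixed-precision pipeline where images arrive as float16:

```python
import tensorflow as tf

image = tf.random.uniform([416, 416, 3], dtype=tf.float16)  # assumed fp16 input
deltav = tf.random.uniform([], 0.5, 1.5)  # float32 by default

# `image *= deltav` would raise an error here: TF does not implicitly promote
# float16 * float32. Casting to the image dtype works for any input dtype.
image *= tf.cast(deltav, image.dtype)
image = tf.clip_by_value(image, 0.0, 1.0)
```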
official/vision/beta/projects/yolo/optimization/__init__.py
0 → 100755
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization package definition."""
# pylint: disable=wildcard-import
from official.modeling.optimization.configs.learning_rate_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.vision.beta.projects.yolo.optimization.configs.optimization_config import *
from official.vision.beta.projects.yolo.optimization.configs.optimizer_config import *
from official.vision.beta.projects.yolo.optimization.optimizer_factory import OptimizerFactory as YoloOptimizerFactory
official/vision/beta/projects/yolo/optimization/configs/__init__.py
0 → 100755
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/yolo/optimization/configs/optimization_config.py
0 → 100755
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for optimization configs.
This file define the dataclass for optimization configs (OptimizationConfig).
It also has two helper functions get_optimizer_config, and get_lr_config from
an OptimizationConfig class.
"""
import dataclasses
from typing import Optional

from official.modeling.optimization.configs import optimization_config as optimization_cfg
from official.vision.beta.projects.yolo.optimization.configs import optimizer_config as opt_cfg


@dataclasses.dataclass
class OptimizerConfig(optimization_cfg.OptimizerConfig):
  """Configuration for optimizer.

  Attributes:
    type: 'str', type of optimizer to be used, one of the fields below.
    sgd_torch: sgd_torch optimizer config.
  """
  type: Optional[str] = None
  sgd_torch: opt_cfg.SGDTorchConfig = opt_cfg.SGDTorchConfig()


@dataclasses.dataclass
class OptimizationConfig(optimization_cfg.OptimizationConfig):
  """Configuration for optimizer and learning rate schedule.

  Attributes:
    optimizer: optimizer oneof config.
    ema: optional exponential moving average optimizer config; if specified, an
      ema optimizer will be used.
    learning_rate: learning rate oneof config.
    warmup: warmup oneof config.
  """
  type: Optional[str] = None
  optimizer: OptimizerConfig = OptimizerConfig()
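A minimal sketch (values assumed) of selecting the new optimizer through this config, mirroring the nested-dict pattern shown in the OptimizerFactory docstring further below; the exact oneof field names for the learning rate and warmup configs are assumptions:

```python
from official.vision.beta.projects.yolo.optimization import OptimizationConfig

opt_config = OptimizationConfig({
    'optimizer': {
        'type': 'sgd_torch',
        'sgd_torch': {'momentum': 0.9, 'momentum_start': 0.8,
                      'warmup_steps': 1000, 'weight_decay': 0.0005}
    },
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.01}},
    'warmup': {'type': 'linear', 'linear': {'warmup_steps': 500}},
})
```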
official/vision/beta/projects/yolo/optimization/configs/optimizer_config.py
0 → 100755
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclasses for optimizer configs."""
import dataclasses
from typing import List, Optional

from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimizer_config


@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, gradients will be clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, gradients will be clipped when
      their absolute value exceeds this value.
    global_clipnorm: float >= 0 or None. If not None, the gradients of all
      weights are clipped so that their global norm is no higher than this
      value.
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
  global_clipnorm: Optional[float] = None


@dataclasses.dataclass
class SGDTorchConfig(optimizer_config.BaseOptimizerConfig):
  """Configuration for the SGD optimizer.

  The attributes for this class match the arguments of tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    decay: decay rate for the SGD optimizer.
    nesterov: nesterov for the SGD optimizer.
    momentum_start: momentum starting point for the SGD optimizer.
    momentum: momentum for the SGD optimizer.
    warmup_steps: number of steps over which momentum is warmed up from
      momentum_start to momentum.
    weight_decay: weight decay coefficient applied to the weight group.
    weight_keys: substrings used to match variables into the weight group.
    bias_keys: substrings used to match variables into the bias group.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum_start: float = 0.0
  momentum: float = 0.9
  warmup_steps: int = 0
  weight_decay: float = 0.0
  weight_keys: Optional[List[str]] = dataclasses.field(
      default_factory=lambda: ["kernel", "weight"])
  bias_keys: Optional[List[str]] = dataclasses.field(
      default_factory=lambda: ["bias", "beta"])
official/vision/beta/projects/yolo/optimization/optimizer_factory.py
0 → 100755
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizer factory class."""
import gin

from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import optimizer_factory
from official.vision.beta.projects.yolo.optimization import sgd_torch

optimizer_factory.OPTIMIZERS_CLS.update({
    'sgd_torch': sgd_torch.SGDTorch,
})

OPTIMIZERS_CLS = optimizer_factory.OPTIMIZERS_CLS
LR_CLS = optimizer_factory.LR_CLS
WARMUP_CLS = optimizer_factory.WARMUP_CLS


class OptimizerFactory(optimizer_factory.OptimizerFactory):
  """Optimizer factory class.

  This class builds learning rate and optimizer based on an optimization
  config. To use this class, you need to do the following:
  (1) Define optimization config, this includes optimizer, and learning rate
      schedule.
  (2) Initialize the class using the optimization config.
  (3) Build learning rate.
  (4) Build optimizer.

  This is a typical example for using this class:

  params = {
      'optimizer': {
          'type': 'sgd',
          'sgd': {'momentum': 0.9}
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {'boundaries': [10000, 20000],
                       'values': [0.1, 0.01, 0.001]}
      },
      'warmup': {
          'type': 'linear',
          'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
      }
  }
  opt_config = OptimizationConfig(params)
  opt_factory = OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()
  optimizer = opt_factory.build_optimizer(lr)
  """
  def get_bias_lr_schedule(self, bias_lr):
    """Build learning rate.

    Builds learning rate from config. Learning rate schedule is built according
    to the learning rate config. If learning rate type is constant,
    lr_config.learning_rate is returned.

    Args:
      bias_lr: learning rate config.

    Returns:
      tf.keras.optimizers.schedules.LearningRateSchedule instance. If
      learning rate type is constant, lr_config.learning_rate is returned.
    """
    if self._lr_type == 'constant':
      lr = self._lr_config.learning_rate
    else:
      lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())

    if self._warmup_config:
      if self._warmup_type != 'linear':
        raise ValueError('Smart Bias is only supported currently with a '
                         'linear warm up.')
      warm_up_cfg = self._warmup_config.as_dict()
      warm_up_cfg['warmup_learning_rate'] = bias_lr
      lr = WARMUP_CLS['linear'](lr, **warm_up_cfg)
    return lr
  @gin.configurable
  def add_ema(self, optimizer):
    """Add EMA to the optimizer independently of the build optimizer method."""
    if self._use_ema:
      optimizer = ema_optimizer.ExponentialMovingAverage(
          optimizer, **self._ema_config.as_dict())
    return optimizer
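Putting the pieces together, a minimal sketch (config values assumed) of the intended flow, matching how create_optimizer in yolo.py further below drives this factory:

```python
opt_factory = YoloOptimizerFactory(opt_config)  # opt_config as sketched earlier

lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)

# Give biases their own linear warmup that starts from a higher "smart bias"
# rate (0.1 here is an assumed value), then reuse the main schedule.
optimizer.set_bias_lr(opt_factory.get_bias_lr_schedule(0.1))
optimizer = opt_factory.add_ema(optimizer)
```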
official/vision/beta/projects/yolo/optimization/sgd_torch.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SGD PyTorch optimizer."""
import re

from absl import logging
import tensorflow as tf

LearningRateSchedule = tf.keras.optimizers.schedules.LearningRateSchedule


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """
  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id
class SGDTorch(tf.keras.optimizers.Optimizer):
  """Optimizer that simulates the SGD module used in PyTorch.

  For details on the differences between the original SGD implementation and
  the one in PyTorch, see:
  https://pytorch.org/docs/stable/generated/torch.optim.SGD.html.

  This optimizer also allows for the usage of a momentum warmup alongside a
  learning rate warmup, though using this is not required.

  Example of usage for training:
  ```python
  opt = SGDTorch(learning_rate, weight_decay=0.0001)

  # Iterate all model.trainable_variables and split the variables by key
  # into the weights, biases, and others.
  opt.search_and_set_variable_groups(model.trainable_variables)

  # If the learning rate schedule on the biases is different. If lr is not set,
  # the default schedule used for weights will be used on the biases.
  opt.set_bias_lr(<lr schedule>)

  # If the learning rate schedule on the others is different. If lr is not set,
  # the default schedule used for weights will be used on the others.
  opt.set_other_lr(<lr schedule>)
  ```
  """

  _HAS_AGGREGATE_GRAD = True
  def __init__(self,
               weight_decay=0.0,
               learning_rate=0.01,
               momentum=0.0,
               momentum_start=0.0,
               warmup_steps=1000,
               nesterov=False,
               name="SGD",
               weight_keys=("kernel", "weight"),
               bias_keys=("bias", "beta"),
               **kwargs):
    super(SGDTorch, self).__init__(name, **kwargs)

    # Create Hyper Params for each group of the LR.
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("bias_learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("other_learning_rate", kwargs.get("lr", learning_rate))

    # SGD decay param.
    self._set_hyper("decay", self._initial_decay)

    # Weight decay param.
    self._weight_decay = weight_decay != 0.0
    self._set_hyper("weight_decay", weight_decay)

    # Enable Momentum.
    self._momentum = False
    if isinstance(momentum, tf.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be between [0, 1].")
    self._set_hyper("momentum", momentum)
    self._set_hyper("momentum_start", momentum_start)
    self._set_hyper("warmup_steps", tf.cast(warmup_steps, tf.int32))

    # Enable Nesterov Momentum.
    self.nesterov = nesterov

    # Variable groups: weights, biases, other.
    self._weight_keys = weight_keys
    self._bias_keys = bias_keys
    self._variables_set = False
    self._wset = set()
    self._bset = set()
    self._oset = set()

    logging.info("Pytorch SGD simulation: ")
    logging.info("Weight Decay: %f", weight_decay)
  def set_bias_lr(self, lr):
    self._set_hyper("bias_learning_rate", lr)

  def set_other_lr(self, lr):
    self._set_hyper("other_learning_rate", lr)

  def _search(self, var, keys):
    """Search all keys for matches. Return True on match."""
    if keys is not None:
      # Variable group is not ignored, so search for the keys.
      for r in keys:
        if re.search(r, var.name) is not None:
          return True
    return False
  def search_and_set_variable_groups(self, variables):
    """Search all variables for matches at each group."""
    weights = []
    biases = []
    others = []
    for var in variables:
      if self._search(var, self._weight_keys):
        # Search for weights.
        weights.append(var)
      elif self._search(var, self._bias_keys):
        # Search for biases.
        biases.append(var)
      else:
        # If all searches fail, add to the other group.
        others.append(var)
    self._set_variable_groups(weights, biases, others)
    return weights, biases, others

  def _set_variable_groups(self, weights, biases, others):
    """Sets the variables to be used in each group."""
    if self._variables_set:
      logging.warning("_set_variable_groups has been called again, indicating "
                      "that the variable groups have already been set; they "
                      "will be updated.")

    self._wset.update(set([_var_key(w) for w in weights]))
    self._bset.update(set([_var_key(b) for b in biases]))
    self._oset.update(set([_var_key(o) for o in others]))
    self._variables_set = True
    return
  def _get_variable_group(self, var, coefficients):
    if self._variables_set:
      # Check which groups hold which variables, preset.
      if _var_key(var) in self._wset:
        return True, False, False
      elif _var_key(var) in self._bset:
        return False, True, False
    else:
      # Search the variables at run time.
      if self._search(var, self._weight_keys):
        return True, False, False
      elif self._search(var, self._bias_keys):
        return False, True, False
    return False, False, True

  def _create_slots(self, var_list):
    """Create a momentum variable for each variable."""
    if self._momentum:
      for var in var_list:
        # Check if trainable to support GPU EMA.
        if var.trainable:
          self.add_slot(var, "momentum")
  def _get_momentum(self, iteration):
    """Get the momentum value."""
    momentum = self._get_hyper("momentum")
    momentum_start = self._get_hyper("momentum_start")
    momentum_warm_up_steps = tf.cast(
        self._get_hyper("warmup_steps"), iteration.dtype)
    value = tf.cond(
        (iteration - momentum_warm_up_steps) <= 0,
        true_fn=lambda: (momentum_start +  # pylint: disable=g-long-lambda
                         (tf.cast(iteration, momentum.dtype) *
                          (momentum - momentum_start) /
                          tf.cast(momentum_warm_up_steps, momentum.dtype))),
        false_fn=lambda: momentum)
    return value
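In words: the momentum is linearly interpolated from `momentum_start` to `momentum` over the first `warmup_steps` iterations and held constant afterwards. A plain-Python sanity check of the schedule, with assumed values:

```python
def momentum_at(step, m_start=0.8, m=0.9, warmup_steps=1000):
  """Mirrors _get_momentum: linear warmup, then constant."""
  if step <= warmup_steps:
    return m_start + step * (m - m_start) / warmup_steps
  return m

assert momentum_at(0) == 0.8
assert abs(momentum_at(500) - 0.85) < 1e-12  # halfway through warmup
assert momentum_at(2000) == 0.9              # past warmup: constant
```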
  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(SGDTorch, self)._prepare_local(
        var_device, var_dtype, apply_state)  # pytype: disable=attribute-error

    weight_decay = self._get_hyper("weight_decay")
    apply_state[(var_device, var_dtype)]["weight_decay"] = tf.cast(
        weight_decay, var_dtype)

    if self._momentum:
      momentum = self._get_momentum(self.iterations)
      momentum = tf.cast(momentum, var_dtype)
      apply_state[(var_device, var_dtype)]["momentum"] = tf.identity(momentum)

    bias_lr = self._get_hyper("bias_learning_rate")
    if isinstance(bias_lr, LearningRateSchedule):
      bias_lr = bias_lr(self.iterations)
    bias_lr = tf.cast(bias_lr, var_dtype)
    apply_state[(var_device, var_dtype)]["bias_lr_t"] = tf.identity(bias_lr)

    other_lr = self._get_hyper("other_learning_rate")
    if isinstance(other_lr, LearningRateSchedule):
      other_lr = other_lr(self.iterations)
    other_lr = tf.cast(other_lr, var_dtype)
    apply_state[(var_device, var_dtype)]["other_lr_t"] = tf.identity(other_lr)

    return apply_state[(var_device, var_dtype)]
  def _apply(self, grad, var, weight_decay, momentum, lr):
    """Uses Pytorch Optimizer with Weight decay SGDW."""
    dparams = grad
    groups = []

    # Do not update non-trainable weights.
    if not var.trainable:
      return tf.group(*groups)

    if self._weight_decay:
      dparams += (weight_decay * var)

    if self._momentum:
      momentum_var = self.get_slot(var, "momentum")
      momentum_update = momentum_var.assign(
          momentum * momentum_var + dparams, use_locking=self._use_locking)
      groups.append(momentum_update)

      if self.nesterov:
        dparams += (momentum * momentum_update)
      else:
        dparams = momentum_update

    weight_update = var.assign_add(-lr * dparams, use_locking=self._use_locking)
    groups.append(weight_update)
    return tf.group(*groups)
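For reference, a restatement of ours of the step _apply performs, matching the PyTorch convention where the learning rate scales the final update rather than being folded into the velocity (as classic Keras SGD does). One non-Nesterov step, checked against hand-computed values:

```python
# v <- momentum * v + (grad + weight_decay * var); var <- var - lr * v
lr, momentum, weight_decay = 0.1, 0.9, 0.01
var, v, grad = 1.0, 0.5, 0.2

dparams = grad + weight_decay * var   # 0.21
v = momentum * v + dparams            # 0.66
var = var - lr * v                    # 0.934

assert abs(v - 0.66) < 1e-12 and abs(var - 0.934) < 1e-12
```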
  def _run_sgd(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                    self._fallback_apply_state(var_device, var_dtype))

    weights, bias, others = self._get_variable_group(var, coefficients)
    weight_decay = tf.zeros_like(coefficients["weight_decay"])
    lr = coefficients["lr_t"]
    if weights:
      weight_decay = coefficients["weight_decay"]
      lr = coefficients["lr_t"]
    elif bias:
      weight_decay = tf.zeros_like(coefficients["weight_decay"])
      lr = coefficients["bias_lr_t"]
    elif others:
      weight_decay = tf.zeros_like(coefficients["weight_decay"])
      lr = coefficients["other_lr_t"]
    momentum = coefficients["momentum"]

    return self._apply(grad, var, weight_decay, momentum, lr)
  def _resource_apply_dense(self, grad, var, apply_state=None):
    return self._run_sgd(grad, var, apply_state=apply_state)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
    holder = tf.tensor_scatter_nd_add(
        tf.zeros_like(var), tf.expand_dims(indices, axis=-1), grad)
    return self._run_sgd(holder, var, apply_state=apply_state)
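The sparse path densifies the gradient so the same momentum and decay logic runs unchanged. A tiny sketch (shapes assumed) of what the scatter produces:

```python
import tensorflow as tf

# Rows 0 and 2 of a 4x2 variable received gradients; scattering them into
# zeros_like(var) yields the dense gradient tensor _run_sgd expects.
var = tf.zeros([4, 2])
indices = tf.constant([0, 2])
grad = tf.constant([[1.0, 1.0], [2.0, 2.0]])
dense = tf.tensor_scatter_nd_add(
    tf.zeros_like(var), tf.expand_dims(indices, axis=-1), grad)
# dense -> [[1., 1.], [0., 0.], [2., 2.], [0., 0.]]
```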
  def get_config(self):
    config = super(SGDTorch, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "momentum": self._serialize_hyperparameter("momentum"),
        "momentum_start": self._serialize_hyperparameter("momentum_start"),
        "warmup_steps": self._serialize_hyperparameter("warmup_steps"),
        "nesterov": self.nesterov,
    })
    return config

  @property
  def learning_rate(self):
    return self._get_hyper("learning_rate")
official/vision/beta/projects/yolo/tasks/task_utils.py
0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for yolo task."""
import tensorflow as tf


class ListMetrics:
  """Private class used to cleanly place the metric values for each level."""

  def __init__(self, metric_names, name="ListMetrics"):
    self.name = name
    self._metric_names = metric_names
    self._metrics = self.build_metric()
    return

  def build_metric(self):
    metric_names = self._metric_names
    metrics = []
    for name in metric_names:
      metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
    return metrics

  def update_state(self, loss_metrics):
    metrics = self._metrics
    for m in metrics:
      m.update_state(loss_metrics[m.name])
    return

  def result(self):
    logs = dict()
    metrics = self._metrics
    for m in metrics:
      logs.update({m.name: m.result()})
    return logs

  def reset_states(self):
    metrics = self._metrics
    for m in metrics:
      m.reset_states()
    return
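A minimal sketch (names and values assumed) of the contract: one Mean per named loss, updated from a dict keyed by metric name and reported back as a dict:

```python
m = ListMetrics(['loss', 'avg_iou', 'avg_obj'], name='3')
m.update_state({'loss': 1.5, 'avg_iou': 0.6, 'avg_obj': 0.7})
logs = m.result()  # {'loss': 1.5, 'avg_iou': 0.6, 'avg_obj': 0.7} as tensors
m.reset_states()
```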
official/vision/beta/projects/yolo/tasks/yolo.py
0 → 100755
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes used to train Yolo."""
import collections
from typing import Optional

from absl import logging
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions
from official.core import input_reader
from official.core import task_factory
from official.modeling import performance
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.ops import box_ops
from official.vision.beta.projects.yolo import optimization
from official.vision.beta.projects.yolo.configs import yolo as exp_cfg
from official.vision.beta.projects.yolo.dataloaders import tf_example_decoder
from official.vision.beta.projects.yolo.dataloaders import yolo_input
from official.vision.beta.projects.yolo.modeling import factory
from official.vision.beta.projects.yolo.ops import mosaic
from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.tasks import task_utils

OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
@task_factory.register_task_cls(exp_cfg.YoloTask)
class YoloTask(base_task.Task):
  """A single-replica view of the training procedure.

  The YOLO task provides artifacts for training/evaluation procedures,
  including loading/iterating over Datasets, initializing the model,
  calculating the loss, post-processing, and customized metrics with
  reduction.
  """
  def __init__(self, params, logging_dir: Optional[str] = None):
    super().__init__(params, logging_dir)
    self.coco_metric = None
    self._loss_fn = None
    self._model = None
    self._coco_91_to_80 = False
    self._metrics = []

    # Globally set the random seed.
    preprocessing_ops.set_random_seeds(seed=params.seed)
    return
  def build_model(self):
    """Build an instance of Yolo."""
    model_base_cfg = self.task_config.model
    l2_weight_decay = self.task_config.weight_decay / 2.0

    input_size = model_base_cfg.input_size.copy()
    input_specs = tf.keras.layers.InputSpec(shape=[None] + input_size)
    l2_regularizer = (
        tf.keras.regularizers.l2(l2_weight_decay) if l2_weight_decay else None)
    model, losses = factory.build_yolo(
        input_specs, model_base_cfg, l2_regularizer)

    # Save for later usage within the task.
    self._loss_fn = losses
    self._model = model
    return model
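The division by 2.0 bridges two conventions: `tf.keras.regularizers.l2(l)` penalizes `l * sum(w**2)`, while weight decay is usually written as `(wd / 2) * sum(w**2)`. A quick check of that factor, under those assumptions:

```python
import tensorflow as tf

weight_decay = 0.1
w = tf.constant([1.0, 2.0])  # sum(w**2) == 5.0
reg = tf.keras.regularizers.l2(weight_decay / 2.0)
assert abs(float(reg(w)) - weight_decay / 2.0 * 5.0) < 1e-6
```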
  def _get_data_decoder(self, params):
    """Get a decoder object to decode the dataset."""
    if params.tfds_name:
      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
    else:
      decoder_cfg = params.decoder.get()
      if params.decoder.type == 'simple_decoder':
        self._coco_91_to_80 = decoder_cfg.coco91_to_80
        decoder = tf_example_decoder.TfExampleDecoder(
            coco91_to_80=decoder_cfg.coco91_to_80,
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      elif params.decoder.type == 'label_map_decoder':
        decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
            label_map=decoder_cfg.label_map,
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      else:
        raise ValueError('Unknown decoder type: {}!'.format(
            params.decoder.type))
    return decoder
  def build_inputs(self, params, input_context=None):
    """Build input dataset."""
    model = self.task_config.model

    # Get the anchor boxes dict based on the model's min and max level.
    backbone = model.backbone.get()
    anchor_dict, level_limits = model.anchor_boxes.get(
        backbone.min_level, backbone.max_level)

    params.seed = self.task_config.seed
    # Set shared parameters between mosaic and yolo_input.
    base_config = dict(
        letter_box=params.parser.letter_box,
        aug_rand_translate=params.parser.aug_rand_translate,
        aug_rand_angle=params.parser.aug_rand_angle,
        aug_rand_perspective=params.parser.aug_rand_perspective,
        area_thresh=params.parser.area_thresh,
        random_flip=params.parser.random_flip,
        seed=params.seed,
    )

    # Get the decoder.
    decoder = self._get_data_decoder(params)

    # Init Mosaic.
    sample_fn = mosaic.Mosaic(
        output_size=model.input_size,
        mosaic_frequency=params.parser.mosaic.mosaic_frequency,
        mixup_frequency=params.parser.mosaic.mixup_frequency,
        jitter=params.parser.mosaic.jitter,
        mosaic_center=params.parser.mosaic.mosaic_center,
        mosaic_crop_mode=params.parser.mosaic.mosaic_crop_mode,
        aug_scale_min=params.parser.mosaic.aug_scale_min,
        aug_scale_max=params.parser.mosaic.aug_scale_max,
        **base_config)

    # Init Parser.
    parser = yolo_input.Parser(
        output_size=model.input_size,
        anchors=anchor_dict,
        use_tie_breaker=params.parser.use_tie_breaker,
        jitter=params.parser.jitter,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        aug_rand_hue=params.parser.aug_rand_hue,
        aug_rand_saturation=params.parser.aug_rand_saturation,
        aug_rand_brightness=params.parser.aug_rand_brightness,
        max_num_instances=params.parser.max_num_instances,
        scale_xy=model.detection_generator.scale_xy.get(),
        expanded_strides=model.detection_generator.path_scales.get(),
        darknet=model.darknet_based_model,
        best_match_only=params.parser.best_match_only,
        anchor_t=params.parser.anchor_thresh,
        random_pad=params.parser.random_pad,
        level_limits=level_limits,
        dtype=params.dtype,
        **base_config)

    # Init the dataset reader.
    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        sample_fn=sample_fn.mosaic_fn(is_training=params.is_training),
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)
    return dataset
  def build_metrics(self, training=True):
    """Build detection metrics."""
    metrics = []

    backbone = self.task_config.model.backbone.get()
    metric_names = collections.defaultdict(list)
    for key in range(backbone.min_level, backbone.max_level + 1):
      key = str(key)
      metric_names[key].append('loss')
      metric_names[key].append('avg_iou')
      metric_names[key].append('avg_obj')

    metric_names['net'].append('box')
    metric_names['net'].append('class')
    metric_names['net'].append('conf')

    for _, key in enumerate(metric_names.keys()):
      metrics.append(task_utils.ListMetrics(metric_names[key], name=key))

    self._metrics = metrics
    if not training:
      annotation_file = self.task_config.annotation_file
      if self._coco_91_to_80:
        annotation_file = None
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=annotation_file,
          include_mask=False,
          need_rescale_bboxes=False,
          per_category_metrics=self._task_config.per_category_metrics)

    return metrics
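For an assumed backbone with `min_level=3` and `max_level=5`, the loop above produces one `ListMetrics` per key of:

```python
import collections

metric_names = collections.defaultdict(list)
for key in range(3, 5 + 1):
  metric_names[str(key)].extend(['loss', 'avg_iou', 'avg_obj'])
metric_names['net'].extend(['box', 'class', 'conf'])
# {'3': [...], '4': [...], '5': [...], 'net': ['box', 'class', 'conf']}
```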
  def build_losses(self, outputs, labels, aux_losses=None):
    """Build YOLO losses."""
    return self._loss_fn(labels, outputs)
  def train_step(self, inputs, model, optimizer, metrics=None):
    """Train step.

    Forward step and backwards propagate the model.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    image, label = inputs

    with tf.GradientTape(persistent=False) as tape:
      # Compute a prediction.
      y_pred = model(image, training=True)

      # Cast to float32 for gradient computation.
      y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)

      # Get the total loss.
      (scaled_loss, metric_loss,
       loss_metrics) = self.build_losses(y_pred['raw_output'], label)

      # Scale the loss for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    # Compute the gradient.
    train_vars = model.trainable_variables
    gradients = tape.gradient(scaled_loss, train_vars)

    # Get the unscaled gradients if we are using the loss scale optimizer on fp16.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      gradients = optimizer.get_unscaled_gradients(gradients)

    # Apply gradients to the model.
    optimizer.apply_gradients(zip(gradients, train_vars))
    logs = {self.loss: metric_loss}

    # Compute all metrics.
    if metrics:
      for m in metrics:
        m.update_state(loss_metrics[m.name])
        logs.update({m.name: m.result()})
    return logs
  def _reorg_boxes(self, boxes, num_detections, image):
    """Scale and clean the boxes prior to evaluation."""
    # Build a prediction mask to take only the number of detections.
    mask = tf.sequence_mask(num_detections, maxlen=tf.shape(boxes)[1])
    mask = tf.cast(tf.expand_dims(mask, axis=-1), boxes.dtype)

    # Denormalize the boxes by the shape of the image.
    inshape = tf.cast(preprocessing_ops.get_image_shape(image), boxes.dtype)
    boxes = box_ops.denormalize_boxes(boxes, inshape)

    # Mask the boxes for usage; padded rows become all -1.
    boxes *= mask
    boxes += (mask - 1)
    return boxes
  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    image, label = inputs

    # Step the model once.
    y_pred = model(image, training=False)
    y_pred = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), y_pred)
    (_, metric_loss, loss_metrics) = self.build_losses(
        y_pred['raw_output'], label)
    logs = {self.loss: metric_loss}

    # Reorganize and rescale the boxes.
    boxes = self._reorg_boxes(y_pred['bbox'], y_pred['num_detections'], image)
    label['groundtruths']['boxes'] = self._reorg_boxes(
        label['groundtruths']['boxes'],
        label['groundtruths']['num_detections'], image)

    # Build the input for the COCO evaluation metric.
    coco_model_outputs = {
        'detection_boxes': boxes,
        'detection_scores': y_pred['confidence'],
        'detection_classes': y_pred['classes'],
        'num_detections': y_pred['num_detections'],
        'source_id': label['groundtruths']['source_id'],
        'image_info': label['groundtruths']['image_info']
    }

    # Compute all metrics.
    if metrics:
      logs.update(
          {self.coco_metric.name: (label['groundtruths'], coco_model_outputs)})
      for m in metrics:
        m.update_state(loss_metrics[m.name])
        logs.update({m.name: m.result()})
    return logs
  def aggregate_logs(self, state=None, step_outputs=None):
    """Get Metric Results."""
    if not state:
      self.coco_metric.reset_states()
      state = self.coco_metric
    self.coco_metric.update_state(
        step_outputs[self.coco_metric.name][0],
        step_outputs[self.coco_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Reduce logs and remove unneeded items. Update with COCO results."""
    res = self.coco_metric.result()
    return res
  def initialize(self, model: tf.keras.Model):
    """Loading pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      logging.info('Training from Scratch.')
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
  def create_optimizer(self,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    opt_factory = optimization.YoloOptimizerFactory(optimizer_config)
    # pylint: disable=protected-access
    ema = opt_factory._use_ema
    opt_factory._use_ema = False

    opt_type = opt_factory._optimizer_type
    if opt_type == 'sgd_torch':
      optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
      optimizer.set_bias_lr(
          opt_factory.get_bias_lr_schedule(self._task_config.smart_bias_lr))
      optimizer.search_and_set_variable_groups(self._model.trainable_variables)
    else:
      optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    opt_factory._use_ema = ema

    if ema:
      logging.info('EMA is enabled.')
    optimizer = opt_factory.add_ema(optimizer)
    # pylint: enable=protected-access

    if runtime_config and runtime_config.loss_scale:
      use_float16 = runtime_config.mixed_precision_dtype == 'float16'
      optimizer = performance.configure_optimizer(
          optimizer,
          use_graph_rewrite=False,
          use_float16=use_float16,
          loss_scale=runtime_config.loss_scale)

    return optimizer
official/vision/beta/projects/yolo/train.py
@@ -12,62 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 # Lint as: python3
 """TensorFlow Model Garden Vision training driver."""

 from absl import app
-from absl import flags
-import gin
-
-from official.common import distribute_utils
 from official.common import flags as tfm_flags
-from official.core import task_factory
-from official.core import train_lib
-from official.core import train_utils
-from official.modeling import performance
+from official.vision.beta import train
 from official.vision.beta.projects.yolo.common import registry_imports  # pylint: disable=unused-import

-FLAGS = flags.FLAGS
-
-'''
-python3 -m official.vision.beta.projects.yolo.train --mode=train_and_eval --experiment=darknet_classification --model_dir=training_dir --config_file=official/vision/beta/projects/yolo/configs/experiments/darknet53_tfds.yaml
-'''
-
-
-def main(_):
-  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
-  print(FLAGS.experiment)
-  params = train_utils.parse_configuration(FLAGS)
-  model_dir = FLAGS.model_dir
-  if 'train' in FLAGS.mode:
-    # Pure eval modes do not output yaml files. Otherwise continuous eval job
-    # may race against the train job for writing the same file.
-    train_utils.serialize_config(params, model_dir)
-
-  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
-  # can have significant impact on model speeds by utilizing float16 in case of
-  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
-  # dtype is float16.
-  if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
-  distribution_strategy = distribute_utils.get_distribution_strategy(
-      distribution_strategy=params.runtime.distribution_strategy,
-      all_reduce_alg=params.runtime.all_reduce_alg,
-      num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
-  with distribution_strategy.scope():
-    task = task_factory.get_task(params.task, logging_dir=model_dir)
-
-  train_lib.run_experiment(
-      distribution_strategy=distribution_strategy,
-      task=task,
-      mode=FLAGS.mode,
-      params=params,
-      model_dir=model_dir)
-
-  train_utils.save_gin_config(FLAGS.mode, model_dir)

 if __name__ == '__main__':
   tfm_flags.define_flags()
-  app.run(main)
+  app.run(train.main)