Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
96eed905
Commit
96eed905
authored
Mar 29, 2021
by
A. Unique TensorFlower
Browse files
Adding gradient clipping for detection models.
PiperOrigin-RevId: 365639389
parent
1dc59163
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
official/vision/detection/configs/shapemask_config.py
official/vision/detection/configs/shapemask_config.py
+1
-1
official/vision/detection/executor/detection_executor.py
official/vision/detection/executor/detection_executor.py
+5
-5
No files found.
official/vision/detection/configs/shapemask_config.py
View file @
96eed905
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.vision.detection.configs
import
base_config
from
official.vision.detection.configs
import
base_config
SHAPEMASK_RESNET_FROZEN_VAR_PREFIX
=
r
'(
resnet\d+/)
conv2d(|_([1-9]|10))\/'
SHAPEMASK_RESNET_FROZEN_VAR_PREFIX
=
r
'(conv2d(|_([1-9]|10))
|batch_normalization(|_([1-9]|10)))
\/'
SHAPEMASK_CFG
=
params_dict
.
ParamsDict
(
base_config
.
BASE_CFG
)
SHAPEMASK_CFG
=
params_dict
.
ParamsDict
(
base_config
.
BASE_CFG
)
SHAPEMASK_CFG
.
override
({
SHAPEMASK_CFG
.
override
({
...
...
official/vision/detection/executor/detection_executor.py
View file @
96eed905
...
@@ -63,10 +63,9 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
...
@@ -63,10 +63,9 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
trainable_variables
)
trainable_variables
)
logging
.
info
(
'Filter trainable variables from %d to %d'
,
logging
.
info
(
'Filter trainable variables from %d to %d'
,
len
(
model
.
trainable_variables
),
len
(
trainable_variables
))
len
(
model
.
trainable_variables
),
len
(
trainable_variables
))
_
update_state
=
lambda
labels
,
outputs
:
None
update_state
_fn
=
lambda
labels
,
outputs
:
None
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
if
isinstance
(
metric
,
tf
.
keras
.
metrics
.
Metric
):
_update_state
=
lambda
labels
,
outputs
:
metric
.
update_state
(
update_state_fn
=
metric
.
update_state
labels
,
outputs
)
else
:
else
:
logging
.
error
(
'Detection: train metric is not an instance of '
logging
.
error
(
'Detection: train metric is not an instance of '
'tf.keras.metrics.Metric.'
)
'tf.keras.metrics.Metric.'
)
...
@@ -82,10 +81,11 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
...
@@ -82,10 +81,11 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
for
k
,
v
in
all_losses
.
items
():
for
k
,
v
in
all_losses
.
items
():
losses
[
k
]
=
tf
.
reduce_mean
(
v
)
losses
[
k
]
=
tf
.
reduce_mean
(
v
)
per_replica_loss
=
losses
[
'total_loss'
]
/
strategy
.
num_replicas_in_sync
per_replica_loss
=
losses
[
'total_loss'
]
/
strategy
.
num_replicas_in_sync
_
update_state
(
labels
,
outputs
)
update_state
_fn
(
labels
,
outputs
)
grads
=
tape
.
gradient
(
per_replica_loss
,
trainable_variables
)
grads
=
tape
.
gradient
(
per_replica_loss
,
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
grads
,
trainable_variables
))
clipped_grads
,
_
=
tf
.
clip_by_global_norm
(
grads
,
clip_norm
=
1.0
)
optimizer
.
apply_gradients
(
zip
(
clipped_grads
,
trainable_variables
))
return
losses
return
losses
return
_replicated_step
return
_replicated_step
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment