Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
06eec91c
Commit
06eec91c
authored
Nov 01, 2021
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Nov 01, 2021
Browse files
Internal change
PiperOrigin-RevId: 406973172
parent
c4ebfef2
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
76 additions
and
43 deletions
+76
-43
official/modeling/multitask/base_trainer.py
official/modeling/multitask/base_trainer.py
+7
-0
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+1
-0
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+1
-0
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+1
-0
official/vision/beta/configs/semantic_segmentation.py
official/vision/beta/configs/semantic_segmentation.py
+1
-0
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+58
-43
official/vision/beta/projects/vit/configs/image_classification.py
.../vision/beta/projects/vit/configs/image_classification.py
+1
-0
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+1
-0
official/vision/beta/tasks/maskrcnn.py
official/vision/beta/tasks/maskrcnn.py
+1
-0
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+2
-0
official/vision/beta/tasks/semantic_segmentation.py
official/vision/beta/tasks/semantic_segmentation.py
+2
-0
No files found.
official/modeling/multitask/base_trainer.py
View file @
06eec91c
...
...
@@ -17,10 +17,12 @@
The trainer derives from the Orbit `StandardTrainer` class.
"""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling
import
optimization
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
...
...
@@ -45,6 +47,11 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
self
.
_training_metrics
=
None
self
.
_global_step
=
orbit
.
utils
.
create_global_step
()
# Creates a shadow copy of the weights to store weights moving average.
if
isinstance
(
self
.
_optimizer
,
optimization
.
ExponentialMovingAverage
)
and
not
self
.
_optimizer
.
has_shadow_copy
:
self
.
_optimizer
.
shadow_copy
(
multi_task_model
)
if
hasattr
(
self
.
multi_task_model
,
"checkpoint_items"
):
checkpoint_items
=
self
.
multi_task_model
.
checkpoint_items
else
:
...
...
official/vision/beta/configs/image_classification.py
View file @
06eec91c
...
...
@@ -70,6 +70,7 @@ class ImageClassificationModel(hyperparams.Config):
@
dataclasses
.
dataclass
class
Losses
(
hyperparams
.
Config
):
loss_weight
:
float
=
1.0
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
...
...
official/vision/beta/configs/maskrcnn.py
View file @
06eec91c
...
...
@@ -185,6 +185,7 @@ class MaskRCNN(hyperparams.Config):
@
dataclasses
.
dataclass
class
Losses
(
hyperparams
.
Config
):
loss_weight
:
float
=
1.0
rpn_huber_loss_delta
:
float
=
1.
/
9.
frcnn_huber_loss_delta
:
float
=
1.
l2_weight_decay
:
float
=
0.0
...
...
official/vision/beta/configs/retinanet.py
View file @
06eec91c
...
...
@@ -83,6 +83,7 @@ class Anchor(hyperparams.Config):
@
dataclasses
.
dataclass
class
Losses
(
hyperparams
.
Config
):
loss_weight
:
float
=
1.0
focal_loss_alpha
:
float
=
0.25
focal_loss_gamma
:
float
=
1.5
huber_loss_delta
:
float
=
0.1
...
...
official/vision/beta/configs/semantic_segmentation.py
View file @
06eec91c
...
...
@@ -92,6 +92,7 @@ class SemanticSegmentationModel(hyperparams.Config):
@
dataclasses
.
dataclass
class
Losses
(
hyperparams
.
Config
):
loss_weight
:
float
=
1.0
label_smoothing
:
float
=
0.0
ignore_label
:
int
=
255
class_weights
:
List
[
float
]
=
dataclasses
.
field
(
default_factory
=
list
)
...
...
official/vision/beta/modeling/factory.py
View file @
06eec91c
...
...
@@ -14,7 +14,7 @@
"""Factory methods to build models."""
# Import libraries
from
typing
import
Optional
import
tensorflow
as
tf
...
...
@@ -41,10 +41,12 @@ from official.vision.beta.modeling.layers import roi_sampler
def
build_classification_model
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
classification_cfg
.
ImageClassificationModel
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
skip_logits_layer
:
bool
=
False
)
->
tf
.
keras
.
Model
:
# pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
skip_logits_layer
:
bool
=
False
,
backbone
:
Optional
[
tf
.
keras
.
Model
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds the classification model."""
norm_activation_config
=
model_config
.
norm_activation
if
not
backbone
:
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
model_config
.
backbone
,
...
...
@@ -66,12 +68,15 @@ def build_classification_model(
return
model
def
build_maskrcnn
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
def
build_maskrcnn
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
maskrcnn_cfg
.
MaskRCNN
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
# pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
backbone
:
Optional
[
tf
.
keras
.
Model
]
=
None
,
decoder
:
Optional
[
tf
.
keras
.
Model
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds Mask R-CNN model."""
norm_activation_config
=
model_config
.
norm_activation
if
not
backbone
:
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
model_config
.
backbone
,
...
...
@@ -79,6 +84,7 @@ def build_maskrcnn(
l2_regularizer
=
l2_regularizer
)
backbone_features
=
backbone
(
tf
.
keras
.
Input
(
input_specs
.
shape
[
1
:]))
if
not
decoder
:
decoder
=
decoders
.
factory
.
build_decoder
(
input_specs
=
backbone
.
output_specs
,
model_config
=
model_config
,
...
...
@@ -121,7 +127,6 @@ def build_maskrcnn(
kernel_regularizer
=
l2_regularizer
,
name
=
'detection_head'
)
# Builds decoder and region proposal network:
if
decoder
:
decoder_features
=
decoder
(
backbone_features
)
rpn_head
(
decoder_features
)
...
...
@@ -253,9 +258,13 @@ def build_maskrcnn(
def
build_retinanet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
retinanet_cfg
.
RetinaNet
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
# pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
backbone
:
Optional
[
tf
.
keras
.
Model
]
=
None
,
decoder
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds RetinaNet model."""
norm_activation_config
=
model_config
.
norm_activation
if
not
backbone
:
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
model_config
.
backbone
,
...
...
@@ -263,6 +272,7 @@ def build_retinanet(
l2_regularizer
=
l2_regularizer
)
backbone_features
=
backbone
(
tf
.
keras
.
Input
(
input_specs
.
shape
[
1
:]))
if
not
decoder
:
decoder
=
decoders
.
factory
.
build_decoder
(
input_specs
=
backbone
.
output_specs
,
model_config
=
model_config
,
...
...
@@ -321,15 +331,20 @@ def build_retinanet(
def
build_segmentation_model
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
segmentation_cfg
.
SemanticSegmentationModel
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
# pytype: disable=annotation-type-mismatch # typed-keras
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
backbone
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
decoder
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds Segmentation model."""
norm_activation_config
=
model_config
.
norm_activation
if
not
backbone
:
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
if
not
decoder
:
decoder
=
decoders
.
factory
.
build_decoder
(
input_specs
=
backbone
.
output_specs
,
model_config
=
model_config
,
...
...
official/vision/beta/projects/vit/configs/image_classification.py
View file @
06eec91c
...
...
@@ -49,6 +49,7 @@ class ImageClassificationModel(hyperparams.Config):
@
dataclasses
.
dataclass
class
Losses
(
hyperparams
.
Config
):
loss_weight
:
float
=
1.0
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
...
...
official/vision/beta/tasks/image_classification.py
View file @
06eec91c
...
...
@@ -169,6 +169,7 @@ class ImageClassificationTask(base_task.Task):
if
aux_losses
:
total_loss
+=
tf
.
add_n
(
aux_losses
)
total_loss
=
losses_config
.
loss_weight
*
total_loss
return
total_loss
def
build_metrics
(
self
,
...
...
official/vision/beta/tasks/maskrcnn.py
View file @
06eec91c
...
...
@@ -236,6 +236,7 @@ class MaskRCNNTask(base_task.Task):
reg_loss
=
tf
.
reduce_sum
(
aux_losses
)
total_loss
=
model_loss
+
reg_loss
total_loss
=
params
.
losses
.
loss_weight
*
total_loss
losses
=
{
'total_loss'
:
total_loss
,
'rpn_score_loss'
:
rpn_score_loss
,
...
...
official/vision/beta/tasks/retinanet.py
View file @
06eec91c
...
...
@@ -220,6 +220,8 @@ class RetinaNetTask(base_task.Task):
reg_loss
=
tf
.
reduce_sum
(
aux_losses
)
total_loss
=
model_loss
+
reg_loss
total_loss
=
params
.
losses
.
loss_weight
*
total_loss
return
total_loss
,
cls_loss
,
box_loss
,
model_loss
def
build_metrics
(
self
,
training
:
bool
=
True
):
...
...
official/vision/beta/tasks/semantic_segmentation.py
View file @
06eec91c
...
...
@@ -140,6 +140,8 @@ class SemanticSegmentationTask(base_task.Task):
if
aux_losses
:
total_loss
+=
tf
.
add_n
(
aux_losses
)
total_loss
=
loss_params
.
loss_weight
*
total_loss
return
total_loss
def
build_metrics
(
self
,
training
:
bool
=
True
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment