Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
307a8194
Commit
307a8194
authored
Aug 08, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Aug 08, 2020
Browse files
Use functions from util map.
PiperOrigin-RevId: 325621239
parent
69221551
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
3 deletions
+42
-3
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+4
-0
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+2
-3
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+36
-0
No files found.
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
307a8194
...
@@ -2583,6 +2583,9 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2583,6 +2583,9 @@ class CenterNetMetaArch(model.DetectionModel):
detections: a dictionary containing the following fields
detections: a dictionary containing the following fields
detection_boxes - A tensor of shape [batch, max_detections, 4]
detection_boxes - A tensor of shape [batch, max_detections, 4]
holding the predicted boxes.
holding the predicted boxes.
detection_boxes_strided: A tensor of shape [batch_size, num_detections,
4] holding the predicted boxes in absolute coordinates of the
feature extractor's final layer output.
detection_scores: A tensor of shape [batch, max_detections] holding
detection_scores: A tensor of shape [batch, max_detections] holding
the predicted score for each box.
the predicted score for each box.
detection_classes: An integer tensor of shape [batch, max_detections]
detection_classes: An integer tensor of shape [batch, max_detections]
...
@@ -2626,6 +2629,7 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2626,6 +2629,7 @@ class CenterNetMetaArch(model.DetectionModel):
fields
.
DetectionResultFields
.
detection_scores
:
scores
,
fields
.
DetectionResultFields
.
detection_scores
:
scores
,
fields
.
DetectionResultFields
.
detection_classes
:
classes
,
fields
.
DetectionResultFields
.
detection_classes
:
classes
,
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
'detection_boxes_strided'
:
boxes_strided
}
}
if
self
.
_kp_params_dict
:
if
self
.
_kp_params_dict
:
...
...
research/object_detection/model_lib_v2.py
View file @
307a8194
...
@@ -28,7 +28,6 @@ import tensorflow.compat.v2 as tf2
...
@@ -28,7 +28,6 @@ import tensorflow.compat.v2 as tf2
from
object_detection
import
eval_util
from
object_detection
import
eval_util
from
object_detection
import
inputs
from
object_detection
import
inputs
from
object_detection
import
model_lib
from
object_detection
import
model_lib
from
object_detection.builders
import
model_builder
from
object_detection.builders
import
optimizer_builder
from
object_detection.builders
import
optimizer_builder
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.protos
import
train_pb2
from
object_detection.protos
import
train_pb2
...
@@ -503,7 +502,7 @@ def train_loop(
...
@@ -503,7 +502,7 @@ def train_loop(
# Build the model, optimizer, and training input
# Build the model, optimizer, and training input
strategy
=
tf
.
compat
.
v2
.
distribute
.
get_strategy
()
strategy
=
tf
.
compat
.
v2
.
distribute
.
get_strategy
()
with
strategy
.
scope
():
with
strategy
.
scope
():
detection_model
=
model_builder
.
build
(
detection_model
=
MODEL_BUILD_UTIL_MAP
[
'detection_model_fn_base'
]
(
model_config
=
model_config
,
is_training
=
True
)
model_config
=
model_config
,
is_training
=
True
)
def
train_dataset_fn
(
input_context
):
def
train_dataset_fn
(
input_context
):
...
@@ -939,7 +938,7 @@ def eval_continuously(
...
@@ -939,7 +938,7 @@ def eval_continuously(
if
kwargs
[
'use_bfloat16'
]:
if
kwargs
[
'use_bfloat16'
]:
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'mixed_bfloat16'
)
tf
.
compat
.
v2
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'mixed_bfloat16'
)
detection_model
=
model_builder
.
build
(
detection_model
=
MODEL_BUILD_UTIL_MAP
[
'detection_model_fn_base'
]
(
model_config
=
model_config
,
is_training
=
True
)
model_config
=
model_config
,
is_training
=
True
)
# Create the inputs.
# Create the inputs.
...
...
research/object_detection/protos/center_net.proto
View file @
307a8194
...
@@ -244,6 +244,42 @@ message CenterNet {
...
@@ -244,6 +244,42 @@ message CenterNet {
optional
ClassificationLoss
classification_loss
=
5
;
optional
ClassificationLoss
classification_loss
=
5
;
}
}
optional
TrackEstimation
track_estimation_task
=
10
;
optional
TrackEstimation
track_estimation_task
=
10
;
// BEGIN GOOGLE-INTERNAL
// Experimental Occupancy network head, use with caution.
message
OccupancyNetMaskPrediction
{
// The loss used for penalizing mask predictions.
optional
ClassificationLoss
classification_loss
=
1
;
// Number of points to sample within a box while training occupancy net.
optional
int32
num_samples
=
2
[
default
=
1000
];
// The dimension of the occupancy embedding.
optional
int32
dim
=
3
[
default
=
256
];
// Weight of occupancy embedding loss.
optional
float
task_loss_weight
=
4
[
default
=
1.0
];
// The stride in pixels at test time when computing the mask. THis is
// useful is computing the full mask is too expensive.
optional
int32
mask_stride
=
5
[
default
=
1
];
// If set, concatenate the occupancy embedding features to (x, y)
// coordinates before feeding it to the occupancy network head.
optional
bool
concat_features
=
6
[
default
=
true
];
// If set to a positive value, defines the length to which the embedding
// is clipped before concatenating to the (x, y) coordinates when
// concat_features=true.
optional
int32
concat_clip
=
7
[
default
=
-
1
];
// The probability threshold to apply for masks to output a binary mask.
optional
float
mask_prob_threshold
=
8
[
default
=
0.5
];
}
optional
OccupancyNetMaskPrediction
occupancy_net_mask_prediction
=
11
;
// EBD GOOGLE-INTERNAL
}
}
message
CenterNetFeatureExtractor
{
message
CenterNetFeatureExtractor
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment