ModelZoo / ResNet50_tensorflow
Commit 307a8194
Authored Aug 08, 2020 by Vighnesh Birodkar
Committed by TF Object Detection Team, Aug 08, 2020
Use functions from util map.
PiperOrigin-RevId: 325621239
Parent: 69221551

Showing 3 changed files with 42 additions and 3 deletions (+42, −3)
research/object_detection/meta_architectures/center_net_meta_arch.py (+4, −0)
research/object_detection/model_lib_v2.py (+2, −3)
research/object_detection/protos/center_net.proto (+36, −0)
research/object_detection/meta_architectures/center_net_meta_arch.py

@@ -2583,6 +2583,9 @@ class CenterNetMetaArch(model.DetectionModel):
       detections: a dictionary containing the following fields
         detection_boxes - A tensor of shape [batch, max_detections, 4]
           holding the predicted boxes.
+        detection_boxes_strided: A tensor of shape [batch_size, num_detections,
+          4] holding the predicted boxes in absolute coordinates of the
+          feature extractor's final layer output.
         detection_scores: A tensor of shape [batch, max_detections] holding
           the predicted score for each box.
         detection_classes: An integer tensor of shape [batch, max_detections]
...
@@ -2626,6 +2629,7 @@ class CenterNetMetaArch(model.DetectionModel):
         fields.DetectionResultFields.detection_scores: scores,
         fields.DetectionResultFields.detection_classes: classes,
         fields.DetectionResultFields.num_detections: num_detections,
+        'detection_boxes_strided': boxes_strided
     }

     if self._kp_params_dict:
...
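The new 'detection_boxes_strided' entry is returned alongside the standard DetectionResultFields keys, so callers can read it straight out of the postprocess output. A minimal sketch, assuming a dummy detections dict shaped like the docstring above; the real dict would come from CenterNetMetaArch.postprocess:

import tensorflow as tf

# Stand-in for the dict returned by CenterNetMetaArch.postprocess
# (batch=1, max_detections=3 chosen arbitrarily for illustration).
detections = {
    'detection_boxes': tf.zeros([1, 3, 4]),
    'detection_scores': tf.zeros([1, 3]),
    'detection_classes': tf.zeros([1, 3], dtype=tf.int32),
    'num_detections': tf.constant([3]),
    'detection_boxes_strided': tf.zeros([1, 3, 4]),
}

# Boxes in absolute coordinates of the feature extractor's final (strided)
# output, as documented in the added docstring lines above.
boxes_strided = detections['detection_boxes_strided']
print(boxes_strided.shape)  # (1, 3, 4)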
research/object_detection/model_lib_v2.py

@@ -28,7 +28,6 @@ import tensorflow.compat.v2 as tf2
 from object_detection import eval_util
 from object_detection import inputs
 from object_detection import model_lib
-from object_detection.builders import model_builder
 from object_detection.builders import optimizer_builder
 from object_detection.core import standard_fields as fields
 from object_detection.protos import train_pb2
...
@@ -503,7 +502,7 @@ def train_loop(
   # Build the model, optimizer, and training input
   strategy = tf.compat.v2.distribute.get_strategy()
   with strategy.scope():
-    detection_model = model_builder.build(
+    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
         model_config=model_config, is_training=True)

   def train_dataset_fn(input_context):
...
@@ -939,7 +938,7 @@ def eval_continuously(
   if kwargs['use_bfloat16']:
     tf.compat.v2.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

-  detection_model = model_builder.build(
+  detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
       model_config=model_config, is_training=True)

   # Create the inputs.
...
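The substance of this commit is the switch from calling model_builder.build directly to looking the builder up in MODEL_BUILD_UTIL_MAP. The point of that indirection is that the builder becomes a named entry in a dict, which tests or internal tooling can override without patching every call site. A minimal, self-contained sketch of the pattern, using stand-in functions rather than the real object_detection code:

# Stand-in util map: keys name build steps, values are the callables
# that implement them (object_detection keeps a similar map in model_lib).
def build_detection_model(model_config, is_training):
  # Placeholder for model_builder.build(...).
  return ('detection_model', model_config, is_training)

MODEL_BUILD_UTIL_MAP = {
    'detection_model_fn_base': build_detection_model,
}

# Call sites resolve the builder by name instead of importing it directly,
# mirroring the change to train_loop and eval_continuously above.
detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
    model_config={'num_classes': 90}, is_training=True)

# A test (or an internal variant) can swap the implementation in one place.
MODEL_BUILD_UTIL_MAP['detection_model_fn_base'] = (
    lambda model_config, is_training: 'stub_model')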
research/object_detection/protos/center_net.proto

@@ -244,6 +244,42 @@ message CenterNet {
     optional ClassificationLoss classification_loss = 5;
   }
   optional TrackEstimation track_estimation_task = 10;

+  // BEGIN GOOGLE-INTERNAL
+  // Experimental Occupancy network head, use with caution.
+  message OccupancyNetMaskPrediction {
+    // The loss used for penalizing mask predictions.
+    optional ClassificationLoss classification_loss = 1;
+
+    // Number of points to sample within a box while training occupancy net.
+    optional int32 num_samples = 2 [default = 1000];
+
+    // The dimension of the occupancy embedding.
+    optional int32 dim = 3 [default = 256];
+
+    // Weight of occupancy embedding loss.
+    optional float task_loss_weight = 4 [default = 1.0];
+
+    // The stride in pixels at test time when computing the mask. This is
+    // useful if computing the full mask is too expensive.
+    optional int32 mask_stride = 5 [default = 1];
+
+    // If set, concatenate the occupancy embedding features to (x, y)
+    // coordinates before feeding it to the occupancy network head.
+    optional bool concat_features = 6 [default = true];
+
+    // If set to a positive value, defines the length to which the embedding
+    // is clipped before concatenating to the (x, y) coordinates when
+    // concat_features=true.
+    optional int32 concat_clip = 7 [default = -1];
+
+    // The probability threshold to apply for masks to output a binary mask.
+    optional float mask_prob_threshold = 8 [default = 0.5];
+  }
+  optional OccupancyNetMaskPrediction occupancy_net_mask_prediction = 11;
+  // END GOOGLE-INTERNAL
+
 }

 message CenterNetFeatureExtractor {
...
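For reference, the new occupancy_net_mask_prediction field would be configured inside the center_net block of a pipeline config. A hedged sketch in protobuf text format, assuming the usual model { center_net { ... } } nesting of pipeline.config; since the message sits between GOOGLE-INTERNAL markers in this diff, it may not be present in external releases. The values simply restate the defaults declared above:

model {
  center_net {
    # ... existing center_net settings ...
    occupancy_net_mask_prediction {
      num_samples: 1000        # points sampled per box during training
      dim: 256                 # occupancy embedding dimension
      task_loss_weight: 1.0
      mask_stride: 1           # test-time stride when computing the mask
      concat_features: true
      concat_clip: -1          # no clipping of the embedding
      mask_prob_threshold: 0.5
    }
  }
}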