Commit 0daae829 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

change on track_reid classification net to solve dimension error in centerMOT pipeline

PiperOrigin-RevId: 394505975
parent 27bd23e5
...@@ -2898,16 +2898,12 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2898,16 +2898,12 @@ class CenterNetMetaArch(model.DetectionModel):
self.track_reid_classification_net = tf.keras.Sequential() self.track_reid_classification_net = tf.keras.Sequential()
for _ in range(self._track_params.num_fc_layers - 1): for _ in range(self._track_params.num_fc_layers - 1):
self.track_reid_classification_net.add( self.track_reid_classification_net.add(
tf.keras.layers.Dense(self._track_params.reid_embed_size, tf.keras.layers.Dense(self._track_params.reid_embed_size))
input_shape=(
self._track_params.reid_embed_size,)))
self.track_reid_classification_net.add( self.track_reid_classification_net.add(
tf.keras.layers.BatchNormalization()) tf.keras.layers.BatchNormalization())
self.track_reid_classification_net.add(tf.keras.layers.ReLU()) self.track_reid_classification_net.add(tf.keras.layers.ReLU())
self.track_reid_classification_net.add( self.track_reid_classification_net.add(
tf.keras.layers.Dense(self._track_params.num_track_ids, tf.keras.layers.Dense(self._track_params.num_track_ids))
input_shape=(
self._track_params.reid_embed_size,)))
if self._temporal_offset_params is not None: if self._temporal_offset_params is not None:
prediction_heads[TEMPORAL_OFFSET] = self._make_prediction_net_list( prediction_heads[TEMPORAL_OFFSET] = self._make_prediction_net_list(
num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset', num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment