Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aef943ed
Unverified
Commit
aef943ed
authored
May 09, 2022
by
SunJong Park
Committed by
GitHub
May 09, 2022
Browse files
Merge branch 'tensorflow:master' into master
parents
67ad909d
930abe21
Changes
74
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
151 additions
and
43 deletions
+151
-43
official/nlp/modeling/networks/xlnet_base.py
official/nlp/modeling/networks/xlnet_base.py
+3
-2
official/nlp/modeling/ops/beam_search.py
official/nlp/modeling/ops/beam_search.py
+1
-1
official/projects/edgetpu/nlp/modeling/edgetpu_layers.py
official/projects/edgetpu/nlp/modeling/edgetpu_layers.py
+1
-1
official/projects/edgetpu/vision/tasks/image_classification.py
...ial/projects/edgetpu/vision/tasks/image_classification.py
+2
-2
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
...riments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
+1
-1
official/projects/qat/vision/modeling/factory.py
official/projects/qat/vision/modeling/factory.py
+8
-0
official/projects/qat/vision/modeling/factory_test.py
official/projects/qat/vision/modeling/factory_test.py
+20
-5
official/projects/qat/vision/quantization/helper.py
official/projects/qat/vision/quantization/helper.py
+45
-0
official/projects/qat/vision/quantization/helper_test.py
official/projects/qat/vision/quantization/helper_test.py
+54
-0
official/projects/qat/vision/quantization/layer_transforms.py
...cial/projects/qat/vision/quantization/layer_transforms.py
+2
-23
official/projects/qat/vision/tasks/retinanet.py
official/projects/qat/vision/tasks/retinanet.py
+4
-0
official/projects/qat/vision/tasks/retinanet_test.py
official/projects/qat/vision/tasks/retinanet_test.py
+1
-0
official/projects/video_ssl/modeling/video_ssl_model.py
official/projects/video_ssl/modeling/video_ssl_model.py
+1
-1
official/projects/yt8m/dataloaders/yt8m_input.py
official/projects/yt8m/dataloaders/yt8m_input.py
+1
-1
official/projects/yt8m/train_test.py
official/projects/yt8m/train_test.py
+2
-1
official/vision/beta/projects/centernet/README.md
official/vision/beta/projects/centernet/README.md
+1
-1
official/vision/beta/projects/simclr/README.md
official/vision/beta/projects/simclr/README.md
+1
-1
official/vision/beta/projects/yolo/losses/yolo_loss.py
official/vision/beta/projects/yolo/losses/yolo_loss.py
+1
-1
official/vision/beta/projects/yolo/modeling/decoders/yolo_decoder.py
...sion/beta/projects/yolo/modeling/decoders/yolo_decoder.py
+1
-1
official/vision/beta/projects/yolo/modeling/layers/nn_blocks.py
...al/vision/beta/projects/yolo/modeling/layers/nn_blocks.py
+1
-1
No files found.
official/nlp/modeling/networks/xlnet_base.py
View file @
aef943ed
...
...
@@ -18,6 +18,7 @@ from absl import logging
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling
import
layers
from
official.nlp.modeling.layers
import
transformer_xl
...
...
@@ -507,7 +508,7 @@ class XLNetBase(tf.keras.layers.Layer):
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
self
.
_vocab_size
,
embedding_width
=
embedding_width
,
initializer
=
self
.
_initializer
,
initializer
=
tf_utils
.
clone_initializer
(
self
.
_initializer
)
,
dtype
=
tf
.
float32
,
name
=
"word_embedding"
)
self
.
_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
...
...
@@ -666,7 +667,7 @@ class XLNetBase(tf.keras.layers.Layer):
shape
=
[
self
.
_num_layers
,
2
,
self
.
_num_attention_heads
,
self
.
_head_size
],
dtype
=
tf
.
float32
,
initializer
=
self
.
_initializer
)
initializer
=
tf_utils
.
clone_initializer
(
self
.
_initializer
)
)
segment_embedding
=
self
.
_segment_embedding
segment_matrix
=
_compute_segment_matrix
(
...
...
official/nlp/modeling/ops/beam_search.py
View file @
aef943ed
...
...
@@ -204,7 +204,7 @@ class SequenceBeamSearch(tf.Module):
candidate_log_probs
=
_log_prob_from_logits
(
logits
)
# Calculate new log probabilities if each of the alive sequences were
# extended # by the
the
candidate IDs.
# extended # by the candidate IDs.
# Shape [batch_size, beam_size, vocab_size]
log_probs
=
candidate_log_probs
+
tf
.
expand_dims
(
alive_log_probs
,
axis
=
2
)
...
...
official/projects/edgetpu/nlp/modeling/edgetpu_layers.py
View file @
aef943ed
...
...
@@ -123,7 +123,7 @@ class EdgeTPUMultiHeadAttention(tf.keras.layers.MultiHeadAttention):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
c
o
stomize attention computation to replace the default dot-product
c
u
stomize attention computation to replace the default dot-product
attention.
Args:
...
...
official/projects/edgetpu/vision/tasks/image_classification.py
View file @
aef943ed
...
...
@@ -265,7 +265,7 @@ class EdgeTPUTask(base_task.Task):
"""Does forward and backward.
Args:
inputs: A tuple of
of
input tensors of (features, labels).
inputs: A tuple of input tensors of (features, labels).
model: A tf.keras.Model instance.
optimizer: The optimizer for this training step.
metrics: A nested structure of metrics objects.
...
...
@@ -319,7 +319,7 @@ class EdgeTPUTask(base_task.Task):
"""Runs validatation step.
Args:
inputs: A tuple of
of
input tensors of (features, labels).
inputs: A tuple of input tensors of (features, labels).
model: A tf.keras.Model instance.
metrics: A nested structure of metrics objects.
...
...
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
View file @
aef943ed
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 2
2.0
# COCO mAP: 2
3.2
# QAT only supports float32 tpu due to fake-quant op.
runtime
:
distribution_strategy
:
'
tpu'
...
...
official/projects/qat/vision/modeling/factory.py
View file @
aef943ed
...
...
@@ -22,6 +22,7 @@ from official.projects.qat.vision.configs import common
from
official.projects.qat.vision.modeling
import
segmentation_model
as
qat_segmentation_model
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
as
dense_prediction_heads_qat
from
official.projects.qat.vision.n_bit
import
schemes
as
n_bit_schemes
from
official.projects.qat.vision.quantization
import
helper
from
official.projects.qat.vision.quantization
import
schemes
from
official.vision
import
configs
from
official.vision.modeling
import
classification_model
...
...
@@ -157,6 +158,7 @@ def build_qat_retinanet(
head
=
(
dense_prediction_heads_qat
.
RetinaNetHeadQuantized
.
from_config
(
head
.
get_config
()))
optimized_model
=
retinanet_model
.
RetinaNetModel
(
optimized_backbone
,
model
.
decoder
,
...
...
@@ -167,6 +169,12 @@ def build_qat_retinanet(
num_scales
=
model_config
.
anchor
.
num_scales
,
aspect_ratios
=
model_config
.
anchor
.
aspect_ratios
,
anchor_size
=
model_config
.
anchor
.
anchor_size
)
if
quantization
.
quantize_detection_head
:
# Call the model with dummy input to build the head part.
dummpy_input
=
tf
.
zeros
([
1
]
+
model_config
.
input_size
)
optimized_model
(
dummpy_input
,
training
=
True
)
helper
.
copy_original_weights
(
model
.
head
,
optimized_model
.
head
)
return
optimized_model
...
...
official/projects/qat/vision/modeling/factory_test.py
View file @
aef943ed
...
...
@@ -21,12 +21,14 @@ import tensorflow as tf
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.modeling
import
factory
as
qat_factory
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
as
qat_dense_prediction_heads
from
official.vision.configs
import
backbones
from
official.vision.configs
import
decoders
from
official.vision.configs
import
image_classification
as
classification_cfg
from
official.vision.configs
import
retinanet
as
retinanet_cfg
from
official.vision.configs
import
semantic_segmentation
as
semantic_segmentation_cfg
from
official.vision.modeling
import
factory
from
official.vision.modeling.heads
import
dense_prediction_heads
class
ClassificationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
@@ -67,9 +69,14 @@ class ClassificationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
class
RetinaNetBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
'spinenet_mobile'
,
(
640
,
640
),
False
),
(
'spinenet_mobile'
,
(
640
,
640
),
False
,
False
),
(
'spinenet_mobile'
,
(
640
,
640
),
False
,
True
),
)
def
test_builder
(
self
,
backbone_type
,
input_size
,
has_attribute_heads
):
def
test_builder
(
self
,
backbone_type
,
input_size
,
has_attribute_heads
,
quantize_detection_head
):
num_classes
=
2
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
input_size
[
0
],
input_size
[
1
],
3
])
...
...
@@ -83,6 +90,7 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
attribute_heads_config
=
None
model_config
=
retinanet_cfg
.
RetinaNet
(
num_classes
=
num_classes
,
input_size
=
[
input_size
[
0
],
input_size
[
1
],
3
],
backbone
=
backbones
.
Backbone
(
type
=
backbone_type
,
spinenet_mobile
=
backbones
.
SpineNetMobile
(
...
...
@@ -92,15 +100,17 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
max_level
=
7
,
use_keras_upsampling_2d
=
True
)),
head
=
retinanet_cfg
.
RetinaNetHead
(
attribute_heads
=
attribute_heads_config
))
attribute_heads
=
attribute_heads_config
,
use_separable_conv
=
True
))
l2_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
5e-5
)
quantization_config
=
common
.
Quantization
()
quantization_config
=
common
.
Quantization
(
quantize_detection_head
=
quantize_detection_head
)
model
=
factory
.
build_retinanet
(
input_specs
=
input_specs
,
model_config
=
model_config
,
l2_regularizer
=
l2_regularizer
)
_
=
qat_factory
.
build_qat_retinanet
(
qat_model
=
qat_factory
.
build_qat_retinanet
(
model
=
model
,
quantization
=
quantization_config
,
model_config
=
model_config
)
...
...
@@ -109,6 +119,11 @@ class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
dict
(
name
=
'att1'
,
type
=
'regression'
,
size
=
1
))
self
.
assertEqual
(
model_config
.
head
.
attribute_heads
[
1
].
as_dict
(),
dict
(
name
=
'att2'
,
type
=
'classification'
,
size
=
2
))
self
.
assertIsInstance
(
qat_model
.
head
,
(
qat_dense_prediction_heads
.
RetinaNetHeadQuantized
if
quantize_detection_head
else
dense_prediction_heads
.
RetinaNetHead
))
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/projects/qat/vision/quantization/helper.py
View file @
aef943ed
...
...
@@ -21,6 +21,51 @@ import tensorflow_model_optimization as tfmot
from
official.projects.qat.vision.quantization
import
configs
_QUANTIZATION_WEIGHT_NAMES
=
[
'output_max'
,
'output_min'
,
'optimizer_step'
,
'kernel_min'
,
'kernel_max'
,
'add_three_min'
,
'add_three_max'
,
'divide_six_min'
,
'divide_six_max'
,
'depthwise_kernel_min'
,
'depthwise_kernel_max'
,
'reduce_mean_quantizer_vars_min'
,
'reduce_mean_quantizer_vars_max'
,
'quantize_layer_min'
,
'quantize_layer_max'
,
'quantize_layer_2_min'
,
'quantize_layer_2_max'
,
'post_activation_min'
,
'post_activation_max'
,
]
_ORIGINAL_WEIGHT_NAME
=
[
'kernel'
,
'depthwise_kernel'
,
'gamma'
,
'beta'
,
'moving_mean'
,
'moving_variance'
,
'bias'
]
def
is_quantization_weight_name
(
name
:
str
)
->
bool
:
simple_name
=
name
.
split
(
'/'
)[
-
1
].
split
(
':'
)[
0
]
if
simple_name
in
_QUANTIZATION_WEIGHT_NAMES
:
return
True
if
simple_name
in
_ORIGINAL_WEIGHT_NAME
:
return
False
raise
ValueError
(
'Variable name {} is not supported.'
.
format
(
simple_name
))
def
copy_original_weights
(
original_model
:
tf
.
keras
.
Model
,
quantized_model
:
tf
.
keras
.
Model
):
"""Helper function that copy the original model weights to quantized model."""
original_weight_value
=
original_model
.
get_weights
()
weight_values
=
quantized_model
.
get_weights
()
original_idx
=
0
for
idx
,
weight
in
enumerate
(
quantized_model
.
weights
):
if
not
is_quantization_weight_name
(
weight
.
name
):
if
original_idx
>=
len
(
original_weight_value
):
raise
ValueError
(
'Not enought original model weights.'
)
weight_values
[
idx
]
=
original_weight_value
[
original_idx
]
original_idx
=
original_idx
+
1
if
original_idx
<
len
(
original_weight_value
):
raise
ValueError
(
'Not enought quantized model weights.'
)
quantized_model
.
set_weights
(
weight_values
)
class
LayerQuantizerHelper
(
object
):
"""Helper class that handles quantizers."""
...
...
official/projects/qat/vision/quantization/helper_test.py
0 → 100644
View file @
aef943ed
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for helper."""
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.projects.qat.vision.quantization
import
helper
class
HelperTest
(
tf
.
test
.
TestCase
):
def
create_simple_model
(
self
):
return
tf
.
keras
.
models
.
Sequential
([
tf
.
keras
.
layers
.
Dense
(
8
,
input_shape
=
(
16
,)),
])
def
test_copy_original_weights_for_simple_model_with_custom_weights
(
self
):
one_model
=
self
.
create_simple_model
()
one_weights
=
[
np
.
ones_like
(
weight
)
for
weight
in
one_model
.
get_weights
()]
one_model
.
set_weights
(
one_weights
)
qat_model
=
tfmot
.
quantization
.
keras
.
quantize_model
(
self
.
create_simple_model
())
zero_weights
=
[
np
.
zeros_like
(
weight
)
for
weight
in
qat_model
.
get_weights
()]
qat_model
.
set_weights
(
zero_weights
)
helper
.
copy_original_weights
(
one_model
,
qat_model
)
qat_model_weights
=
qat_model
.
get_weights
()
count
=
0
for
idx
,
weight
in
enumerate
(
qat_model
.
weights
):
if
not
helper
.
is_quantization_weight_name
(
weight
.
name
):
self
.
assertAllEqual
(
qat_model_weights
[
idx
],
np
.
ones_like
(
qat_model_weights
[
idx
]))
count
+=
1
self
.
assertLen
(
one_model
.
weights
,
count
)
self
.
assertGreater
(
len
(
qat_model
.
weights
),
len
(
one_model
.
weights
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/projects/qat/vision/quantization/layer_transforms.py
View file @
aef943ed
...
...
@@ -21,6 +21,7 @@ import tensorflow_model_optimization as tfmot
from
official.projects.qat.vision.modeling.layers
import
nn_blocks
as
quantized_nn_blocks
from
official.projects.qat.vision.modeling.layers
import
nn_layers
as
quantized_nn_layers
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
helper
keras
=
tf
.
keras
LayerNode
=
tfmot
.
quantization
.
keras
.
graph_transformations
.
transforms
.
LayerNode
...
...
@@ -31,18 +32,6 @@ _LAYER_NAMES = [
'Vision>SegmentationHead'
,
'Vision>SpatialPyramidPooling'
,
'Vision>ASPP'
]
_QUANTIZATION_WEIGHT_NAMES
=
[
'output_max'
,
'output_min'
,
'optimizer_step'
,
'kernel_min'
,
'kernel_max'
,
'add_three_min'
,
'add_three_max'
,
'divide_six_min'
,
'divide_six_max'
,
'depthwise_kernel_min'
,
'depthwise_kernel_max'
,
'reduce_mean_quantizer_vars_min'
,
'reduce_mean_quantizer_vars_max'
]
_ORIGINAL_WEIGHT_NAME
=
[
'kernel'
,
'depthwise_kernel'
,
'gamma'
,
'beta'
,
'moving_mean'
,
'moving_variance'
,
'bias'
]
class
CustomLayerQuantize
(
tfmot
.
quantization
.
keras
.
graph_transformations
.
transforms
.
Transform
):
...
...
@@ -58,16 +47,6 @@ class CustomLayerQuantize(
"""See base class."""
return
LayerPattern
(
self
.
_original_layer_pattern
)
def
_is_quantization_weight_name
(
self
,
name
):
simple_name
=
name
.
split
(
'/'
)[
-
1
].
split
(
':'
)[
0
]
if
simple_name
in
_QUANTIZATION_WEIGHT_NAMES
:
return
True
if
simple_name
in
_ORIGINAL_WEIGHT_NAME
:
return
False
raise
ValueError
(
'Variable name {} is not supported on '
'CustomLayerQuantize({}) transform.'
.
format
(
simple_name
,
self
.
_original_layer_pattern
))
def
_create_layer_metadata
(
self
,
layer_class_name
:
str
)
->
Mapping
[
str
,
tfmot
.
quantization
.
keras
.
QuantizeConfig
]:
...
...
@@ -97,7 +76,7 @@ class CustomLayerQuantize(
match_idx
=
0
names_and_weights
=
[]
for
name_and_weight
in
quantized_names_and_weights
:
if
not
s
el
f
.
_
is_quantization_weight_name
(
name
=
name_and_weight
[
0
]):
if
not
h
el
per
.
is_quantization_weight_name
(
name
=
name_and_weight
[
0
]):
name_and_weight
=
bottleneck_names_and_weights
[
match_idx
]
match_idx
=
match_idx
+
1
names_and_weights
.
append
(
name_and_weight
)
...
...
official/projects/qat/vision/tasks/retinanet.py
View file @
aef943ed
...
...
@@ -28,6 +28,10 @@ class RetinaNetTask(retinanet.RetinaNetTask):
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
"""Builds RetinaNet model with QAT."""
model
=
super
(
RetinaNetTask
,
self
).
build_model
()
# Call the model with dummy input to build the head part.
dummpy_input
=
tf
.
zeros
([
1
]
+
self
.
task_config
.
model
.
input_size
)
model
(
dummpy_input
,
training
=
True
)
if
self
.
task_config
.
quantization
:
model
=
factory
.
build_qat_retinanet
(
model
,
...
...
official/projects/qat/vision/tasks/retinanet_test.py
View file @
aef943ed
...
...
@@ -65,6 +65,7 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
task
=
retinanet
.
RetinaNetTask
(
config
.
task
)
model
=
task
.
build_model
()
self
.
assertLen
(
model
.
weights
,
2393
)
metrics
=
task
.
build_metrics
(
training
=
is_training
)
strategy
=
tf
.
distribute
.
get_strategy
()
...
...
official/projects/video_ssl/modeling/video_ssl_model.py
View file @
aef943ed
...
...
@@ -53,7 +53,7 @@ class VideoSSLModel(tf.keras.Model):
hidden_dim: `int` number of hidden units in MLP.
hidden_layer_num: `int` number of hidden layers in MLP.
hidden_norm_args: `dict` for batchnorm arguments in MLP.
projection_dim: `int` number of ouput dimension for MLP.
projection_dim: `int` number of ou
t
put dimension for MLP.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
dropout_rate: `float` rate for dropout regularization.
aggregate_endpoints: `bool` aggregate all end ponits or only use the
...
...
official/projects/yt8m/dataloaders/yt8m_input.py
View file @
aef943ed
...
...
@@ -33,7 +33,7 @@ from official.vision.dataloaders import parser
def
resize_axis
(
tensor
,
axis
,
new_size
,
fill_value
=
0
):
"""Truncates or pads a tensor to new_size on
on
a given axis.
"""Truncates or pads a tensor to new_size on a given axis.
Truncate or extend tensor such that tensor.shape[axis] == new_size. If the
size increases, the padding will be performed at the end, using fill_value.
...
...
official/projects/yt8m/train_test.py
View file @
aef943ed
...
...
@@ -82,7 +82,8 @@ class TrainTest(parameterized.TestCase, tf.test.TestCase):
})
FLAGS
.
params_override
=
params_override
train_lib
.
train
.
main
(
'unused_args'
)
with
train_lib
.
train
.
gin
.
unlock_config
():
train_lib
.
train
.
main
(
'unused_args'
)
FLAGS
.
mode
=
'eval'
...
...
official/vision/beta/projects/centernet/README.md
View file @
aef943ed
...
...
@@ -22,7 +22,7 @@ heatmaps (one heatmap for each class) is needed to predict the object. CenterNet
proves that this can be done without a significant difference in accuracy.
## Enviroment setup
## Enviro
n
ment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
...
...
official/vision/beta/projects/simclr/README.md
View file @
aef943ed
...
...
@@ -10,7 +10,7 @@
An illustration of SimCLR (from
<a
href=
"https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html"
>
our blog here
</a>
).
</div>
## Enviroment setup
## Enviro
n
ment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
...
...
official/vision/beta/projects/yolo/losses/yolo_loss.py
View file @
aef943ed
...
...
@@ -323,7 +323,7 @@ class DarknetLoss(YoloLossBase):
grid_points
=
tf
.
stop_gradient
(
grid_points
)
anchor_grid
=
tf
.
stop_gradient
(
anchor_grid
)
# Split all the ground truths to use as sep
e
rate items in loss computation.
# Split all the ground truths to use as sep
a
rate items in loss computation.
(
true_box
,
ind_mask
,
true_class
)
=
tf
.
split
(
y_true
,
[
4
,
1
,
1
],
axis
=-
1
)
true_conf
=
tf
.
squeeze
(
true_conf
,
axis
=-
1
)
true_class
=
tf
.
squeeze
(
true_class
,
axis
=-
1
)
...
...
official/vision/beta/projects/yolo/modeling/decoders/yolo_decoder.py
View file @
aef943ed
...
...
@@ -613,7 +613,7 @@ def build_yolo_decoder(
'{yolo_model.YOLO_MODELS[decoder_cfg.version].keys()}'
'or specify a custom decoder config using YoloDecoder.'
)
base_model
=
YOLO_MODELS
[
decoder_cfg
.
version
][
decoder_cfg
.
type
]
base_model
=
YOLO_MODELS
[
decoder_cfg
.
version
][
decoder_cfg
.
type
]
.
copy
()
cfg_dict
=
decoder_cfg
.
as_dict
()
for
key
in
base_model
:
...
...
official/vision/beta/projects/yolo/modeling/layers/nn_blocks.py
View file @
aef943ed
...
...
@@ -823,7 +823,7 @@ class CSPStack(tf.keras.layers.Layer):
make it a cross stage partial. Added for ease of use. you should be able
to wrap any layer stack with a CSP independent of wether it belongs
to the Darknet family. if filter_scale = 2, then the blocks in the stack
passed into the
the
CSP stack should also have filters = filters/filter_scale
passed into the CSP stack should also have filters = filters/filter_scale
Cross Stage Partial networks (CSPNets) were proposed in:
[1] Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu,
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment