Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ca552843
Unverified
Commit
ca552843
authored
Sep 16, 2021
by
Srihari Humbarwadi
Committed by
GitHub
Sep 16, 2021
Browse files
Merge branch 'panoptic-segmentation' into panoptic-segmentation
parents
7e2f7a35
6b90e134
Changes
283
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
325 additions
and
240 deletions
+325
-240
official/vision/beta/projects/movinet/README.md
official/vision/beta/projects/movinet/README.md
+2
-2
official/vision/beta/projects/movinet/export_saved_model_test.py
...l/vision/beta/projects/movinet/export_saved_model_test.py
+2
-2
official/vision/beta/projects/movinet/modeling/movinet.py
official/vision/beta/projects/movinet/modeling/movinet.py
+24
-21
official/vision/beta/projects/movinet/modeling/movinet_layers.py
...l/vision/beta/projects/movinet/modeling/movinet_layers.py
+2
-2
official/vision/beta/projects/movinet/modeling/movinet_test.py
...ial/vision/beta/projects/movinet/modeling/movinet_test.py
+10
-10
official/vision/beta/projects/panoptic_maskrcnn/configs/panoptic_maskrcnn.py
...a/projects/panoptic_maskrcnn/configs/panoptic_maskrcnn.py
+19
-1
official/vision/beta/projects/panoptic_maskrcnn/dataloaders/panoptic_maskrcnn_input.py
.../panoptic_maskrcnn/dataloaders/panoptic_maskrcnn_input.py
+84
-35
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py
...anoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py
+1
-1
official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_maskrcnn.py
...eta/projects/panoptic_maskrcnn/tasks/panoptic_maskrcnn.py
+10
-7
official/vision/beta/projects/simclr/common/registry_imports.py
...al/vision/beta/projects/simclr/common/registry_imports.py
+0
-14
official/vision/beta/projects/simclr/configs/experiments/imagenet_simclr_multitask_tpu.yaml
...lr/configs/experiments/imagenet_simclr_multitask_tpu.yaml
+138
-0
official/vision/beta/projects/simclr/configs/multitask_config.py
...l/vision/beta/projects/simclr/configs/multitask_config.py
+11
-4
official/vision/beta/projects/simclr/configs/simclr.py
official/vision/beta/projects/simclr/configs/simclr.py
+20
-34
official/vision/beta/projects/simclr/configs/simclr_test.py
official/vision/beta/projects/simclr/configs/simclr_test.py
+1
-17
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py
...vision/beta/projects/simclr/dataloaders/preprocess_ops.py
+0
-14
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
...l/vision/beta/projects/simclr/dataloaders/simclr_input.py
+0
-14
official/vision/beta/projects/simclr/heads/simclr_head.py
official/vision/beta/projects/simclr/heads/simclr_head.py
+1
-15
official/vision/beta/projects/simclr/heads/simclr_head_test.py
...ial/vision/beta/projects/simclr/heads/simclr_head_test.py
+0
-16
official/vision/beta/projects/simclr/losses/contrastive_losses.py
.../vision/beta/projects/simclr/losses/contrastive_losses.py
+0
-15
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py
...on/beta/projects/simclr/losses/contrastive_losses_test.py
+0
-16
No files found.
official/vision/beta/projects/movinet/README.md
View file @
ca552843
...
...
@@ -338,7 +338,7 @@ with the Python API:
```
python
# Create the interpreter and signature runner
interpreter
=
tf
.
lite
.
Interpreter
(
'/tmp/movinet_a0_stream.tflite'
)
signature
=
interpreter
.
get_signature_runner
()
runner
=
interpreter
.
get_signature_runner
()
# Extract state names and create the initial (zero) states
def
state_name
(
name
:
str
)
->
str
:
...
...
@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
states
=
init_states
for
clip
in
clips
:
# Input shape: [1, 1, 172, 172, 3]
outputs
=
signature
(
**
states
,
image
=
clip
)
outputs
=
runner
(
**
states
,
image
=
clip
)
logits
=
outputs
.
pop
(
'logits'
)
states
=
outputs
```
...
...
official/vision/beta/projects/movinet/export_saved_model_test.py
View file @
ca552843
...
...
@@ -121,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
tflite_model
=
converter
.
convert
()
interpreter
=
tf
.
lite
.
Interpreter
(
model_content
=
tflite_model
)
signature
=
interpreter
.
get_signature_runner
()
runner
=
interpreter
.
get_signature_runner
(
'serving_default'
)
def
state_name
(
name
:
str
)
->
str
:
return
name
[
len
(
'serving_default_'
):
-
len
(
':0'
)]
...
...
@@ -137,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
states
=
init_states
for
clip
in
clips
:
outputs
=
signature
(
**
states
,
image
=
clip
)
outputs
=
runner
(
**
states
,
image
=
clip
)
logits
=
outputs
.
pop
(
'logits'
)
states
=
outputs
...
...
official/vision/beta/projects/movinet/modeling/movinet.py
View file @
ca552843
...
...
@@ -17,10 +17,10 @@
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
import
dataclasses
import
math
from
typing
import
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
,
Union
import
dataclasses
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
...
...
@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_idx
=
1
for
block_idx
,
block
in
enumerate
(
self
.
_block_specs
):
if
isinstance
(
block
,
StemSpec
):
x
,
states
=
movinet_layers
.
Stem
(
layer_obj
=
movinet_layers
.
Stem
(
block
.
filters
,
block
.
kernel_size
,
block
.
strides
,
...
...
@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
state_prefix
=
'state
/
stem'
,
name
=
'stem'
)
(
x
,
states
=
states
)
state_prefix
=
'state
_
stem'
,
name
=
'stem'
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
endpoints
[
'stem'
]
=
x
elif
isinstance
(
block
,
MovinetBlockSpec
):
if
not
(
len
(
block
.
expand_filters
)
==
len
(
block
.
kernel_sizes
)
==
...
...
@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
self
.
_stochastic_depth_drop_rate
*
stochastic_depth_idx
/
num_layers
)
expand_filters
,
kernel_size
,
strides
=
layer
name
=
f
'b
{
block_idx
-
1
}
/l
{
layer_idx
}
'
x
,
states
=
movinet_layers
.
MovinetBlock
(
name
=
f
'b
lock
{
block_idx
-
1
}
_layer
{
layer_idx
}
'
layer_obj
=
movinet_layers
.
MovinetBlock
(
block
.
base_filters
,
expand_filters
,
kernel_size
=
kernel_size
,
...
...
@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
state_prefix
=
f
'state/
{
name
}
'
,
name
=
name
)(
x
,
states
=
states
)
state_prefix
=
f
'state_
{
name
}
'
,
name
=
name
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
endpoints
[
name
]
=
x
stochastic_depth_idx
+=
1
elif
isinstance
(
block
,
HeadSpec
):
x
,
states
=
movinet_layers
.
Head
(
layer_obj
=
movinet_layers
.
Head
(
project_filters
=
block
.
project_filters
,
conv_type
=
self
.
_conv_type
,
activation
=
self
.
_activation
,
...
...
@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
state_prefix
=
'state
/
head'
,
name
=
'head'
)
(
x
,
states
=
states
)
state_prefix
=
'state
_
head'
,
name
=
'head'
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
endpoints
[
'head'
]
=
x
else
:
raise
ValueError
(
'Unknown block type {}'
.
format
(
block
))
...
...
@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
for
block_idx
,
block
in
enumerate
(
block_specs
):
if
isinstance
(
block
,
StemSpec
):
if
block
.
kernel_size
[
0
]
>
1
:
states
[
'state
/
stem
/
stream_buffer'
]
=
(
states
[
'state
_
stem
_
stream_buffer'
]
=
(
input_shape
[
0
],
input_shape
[
1
],
divide_resolution
(
input_shape
[
2
],
num_downsamples
),
...
...
@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
self
.
_conv_type
in
[
'2plus1d'
,
'3d_2plus1d'
]):
num_downsamples
+=
1
prefix
=
f
'state_block
{
block_idx
}
_layer
{
layer_idx
}
'
if
kernel_size
[
0
]
>
1
:
states
[
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
stream_buffer'
]
=
(
states
[
f
'
{
prefix
}
_
stream_buffer'
]
=
(
input_shape
[
0
],
kernel_size
[
0
]
-
1
,
divide_resolution
(
input_shape
[
2
],
num_downsamples
),
...
...
@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
expand_filters
,
)
states
[
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
pool_buffer'
]
=
(
states
[
f
'
{
prefix
}
_
pool_buffer'
]
=
(
input_shape
[
0
],
1
,
1
,
1
,
expand_filters
,
)
states
[
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
pool_frame_count'
]
=
(
1
,)
states
[
f
'
{
prefix
}
_
pool_frame_count'
]
=
(
1
,)
if
use_positional_encoding
:
name
=
f
'
state/b
{
block_idx
}
/l
{
layer_idx
}
/
pos_enc_frame_count'
name
=
f
'
{
prefix
}
_
pos_enc_frame_count'
states
[
name
]
=
(
1
,)
if
strides
[
1
]
!=
strides
[
2
]:
...
...
@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
self
.
_conv_type
not
in
[
'2plus1d'
,
'3d_2plus1d'
]):
num_downsamples
+=
1
elif
isinstance
(
block
,
HeadSpec
):
states
[
'state
/
head
/
pool_buffer'
]
=
(
states
[
'state
_
head
_
pool_buffer'
]
=
(
input_shape
[
0
],
1
,
1
,
1
,
block
.
project_filters
,
)
states
[
'state
/
head
/
pool_frame_count'
]
=
(
1
,)
states
[
'state
_
head
_
pool_frame_count'
]
=
(
1
,)
return
states
...
...
official/vision/beta/projects/movinet/modeling/movinet_layers.py
View file @
ca552843
...
...
@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
state_prefix
=
state_prefix
if
state_prefix
is
not
None
else
''
self
.
_state_prefix
=
state_prefix
self
.
_state_name
=
f
'
{
state_prefix
}
/
stream_buffer'
self
.
_state_name
=
f
'
{
state_prefix
}
_
stream_buffer'
self
.
_buffer_size
=
buffer_size
def
get_config
(
self
):
...
...
@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
inputs: the input tensor.
states: a dict of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '
/
stream_buffer'`.
Expected keys include `state_prefix + '
_
stream_buffer'`.
Returns:
the output tensor and states
...
...
official/vision/beta/projects/movinet/modeling/movinet_test.py
View file @
ca552843
...
...
@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints
,
states
=
network
(
inputs
)
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
0/l
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
1/l
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
2/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
3/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
4/l
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'b
lock0_layer
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
lock1_layer
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
lock2_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock3_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock4_layer
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertNotEmpty
(
states
)
...
...
@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints
,
new_states
=
backbone
({
**
init_states
,
'image'
:
inputs
})
self
.
assertAllEqual
(
endpoints
[
'stem'
].
shape
,
[
1
,
8
,
64
,
64
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
0/l
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
1/l
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
2/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
3/l
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
4/l
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'b
lock0_layer
0'
].
shape
,
[
1
,
8
,
32
,
32
,
8
])
self
.
assertAllEqual
(
endpoints
[
'b
lock1_layer
0'
].
shape
,
[
1
,
8
,
16
,
16
,
32
])
self
.
assertAllEqual
(
endpoints
[
'b
lock2_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock3_layer
0'
].
shape
,
[
1
,
8
,
8
,
8
,
56
])
self
.
assertAllEqual
(
endpoints
[
'b
lock4_layer
0'
].
shape
,
[
1
,
8
,
4
,
4
,
104
])
self
.
assertAllEqual
(
endpoints
[
'head'
].
shape
,
[
1
,
1
,
1
,
1
,
480
])
self
.
assertNotEmpty
(
init_states
)
...
...
official/vision/beta/projects/panoptic_maskrcnn/configs/panoptic_maskrcnn.py
View file @
ca552843
...
...
@@ -22,6 +22,7 @@ from official.core import config_definitions as cfg
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
from
official.vision.beta.configs
import
common
from
official.vision.beta.configs
import
maskrcnn
from
official.vision.beta.configs
import
semantic_segmentation
...
...
@@ -47,14 +48,31 @@ class Parser(maskrcnn.Parser):
segmentation_groundtruth_padded_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
segmentation_ignore_label
:
int
=
255
panoptic_ignore_label
:
int
=
0
# Setting this to true will enable parsing category_mask and instance_mask.
include_panoptic_masks
:
bool
=
True
@
dataclasses
.
dataclass
class
TfExampleDecoder
(
common
.
TfExampleDecoder
):
"""A simple TF Example decoder config."""
# Setting this to true will enable decoding category_mask and instance_mask.
include_panoptic_masks
:
bool
=
True
@
dataclasses
.
dataclass
class
DataDecoder
(
common
.
DataDecoder
):
"""Data decoder config."""
simple_decoder
:
TfExampleDecoder
=
TfExampleDecoder
()
@
dataclasses
.
dataclass
class
DataConfig
(
maskrcnn
.
DataConfig
):
"""Input config for training."""
decoder
:
DataDecoder
=
DataDecoder
()
parser
:
Parser
=
Parser
()
# @dataclasses.dataclass
@
dataclasses
.
dataclass
class
PanopticSegmentationGenerator
(
hyperparams
.
Config
):
output_size
:
List
[
int
]
=
dataclasses
.
field
(
...
...
official/vision/beta/projects/panoptic_maskrcnn/dataloaders/panoptic_maskrcnn_input.py
View file @
ca552843
...
...
@@ -24,25 +24,51 @@ from official.vision.beta.ops import preprocess_ops
class
TfExampleDecoder
(
tf_example_decoder
.
TfExampleDecoder
):
"""Tensorflow Example proto decoder."""
def
__init__
(
self
,
regenerate_source_id
,
mask_binarize_threshold
):
def
__init__
(
self
,
regenerate_source_id
,
mask_binarize_threshold
,
include_panoptic_masks
):
super
(
TfExampleDecoder
,
self
).
__init__
(
include_mask
=
True
,
regenerate_source_id
=
regenerate_source_id
,
mask_binarize_threshold
=
None
)
self
.
_segmentation_keys_to_features
=
{
self
.
_include_panoptic_masks
=
include_panoptic_masks
keys_to_features
=
{
'image/segmentation/class/encoded'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
)
}
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
)}
if
include_panoptic_masks
:
keys_to_features
.
update
({
'image/panoptic/category_mask'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
'image/panoptic/instance_mask'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
)})
self
.
_segmentation_keys_to_features
=
keys_to_features
def
decode
(
self
,
serialized_example
):
decoded_tensors
=
super
(
TfExampleDecoder
,
self
).
decode
(
serialized_example
)
segmentation_
parsed_tensors
=
tf
.
io
.
parse_single_example
(
parsed_tensors
=
tf
.
io
.
parse_single_example
(
serialized_example
,
self
.
_segmentation_keys_to_features
)
segmentation_mask
=
tf
.
io
.
decode_image
(
segmentation_
parsed_tensors
[
'image/segmentation/class/encoded'
],
parsed_tensors
[
'image/segmentation/class/encoded'
],
channels
=
1
)
segmentation_mask
.
set_shape
([
None
,
None
,
1
])
decoded_tensors
.
update
({
'groundtruth_segmentation_mask'
:
segmentation_mask
})
if
self
.
_include_panoptic_masks
:
category_mask
=
tf
.
io
.
decode_image
(
parsed_tensors
[
'image/panoptic/category_mask'
],
channels
=
1
)
instance_mask
=
tf
.
io
.
decode_image
(
parsed_tensors
[
'image/panoptic/instance_mask'
],
channels
=
1
)
category_mask
.
set_shape
([
None
,
None
,
1
])
instance_mask
.
set_shape
([
None
,
None
,
1
])
decoded_tensors
.
update
({
'groundtruth_panoptic_category_mask'
:
category_mask
,
'groundtruth_panoptic_instance_mask'
:
instance_mask
})
return
decoded_tensors
...
...
@@ -69,6 +95,8 @@ class Parser(maskrcnn_input.Parser):
segmentation_resize_eval_groundtruth
=
True
,
segmentation_groundtruth_padded_size
=
None
,
segmentation_ignore_label
=
255
,
panoptic_ignore_label
=
0
,
include_panoptic_masks
=
True
,
dtype
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
...
...
@@ -106,8 +134,12 @@ class Parser(maskrcnn_input.Parser):
segmentation_groundtruth_padded_size: `Tensor` or `list` for [height,
width]. When resize_eval_groundtruth is set to False, the groundtruth
masks are padded to this size.
segmentation_ignore_label: `int` the pixel with ignore label will not used
for training and evaluation.
segmentation_ignore_label: `int` the pixels with ignore label will not be
used for training and evaluation.
panoptic_ignore_label: `int` the pixels with ignore label will not be used
by the PQ evaluator.
include_panoptic_masks: `bool`, if True, category_mask and instance_mask
will be parsed. Set this to true if PQ evaluator is enabled.
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
"""
super
(
Parser
,
self
).
__init__
(
...
...
@@ -139,6 +171,8 @@ class Parser(maskrcnn_input.Parser):
'specified when segmentation_resize_eval_groundtruth is False.'
)
self
.
_segmentation_groundtruth_padded_size
=
segmentation_groundtruth_padded_size
self
.
_segmentation_ignore_label
=
segmentation_ignore_label
self
.
_panoptic_ignore_label
=
panoptic_ignore_label
self
.
_include_panoptic_masks
=
include_panoptic_masks
def
_parse_train_data
(
self
,
data
):
"""Parses data for training.
...
...
@@ -250,39 +284,54 @@ class Parser(maskrcnn_input.Parser):
shape [height_l, width_l, 4] representing anchor boxes at each
level.
"""
segmentation_mask
=
tf
.
cast
(
data
[
'groundtruth_segmentation_mask'
],
tf
.
float32
)
segmentation_mask
=
tf
.
reshape
(
segmentation_mask
,
shape
=
[
1
,
data
[
'height'
],
data
[
'width'
],
1
])
segmentation_mask
+=
1
def
_process_mask
(
mask
,
ignore_label
,
image_info
):
mask
=
tf
.
cast
(
mask
,
dtype
=
tf
.
float32
)
mask
=
tf
.
reshape
(
mask
,
shape
=
[
1
,
data
[
'height'
],
data
[
'width'
],
1
])
mask
+=
1
if
self
.
_segmentation_resize_eval_groundtruth
:
# Resizes eval masks to match input image sizes. In that case, mean IoU
# is computed on output_size not the original size of the images.
image_scale
=
image_info
[
2
,
:]
offset
=
image_info
[
3
,
:]
mask
=
preprocess_ops
.
resize_and_crop_masks
(
mask
,
image_scale
,
self
.
_output_size
,
offset
)
else
:
mask
=
tf
.
image
.
pad_to_bounding_box
(
mask
,
0
,
0
,
self
.
_segmentation_groundtruth_padded_size
[
0
],
self
.
_segmentation_groundtruth_padded_size
[
1
])
mask
-=
1
# Assign ignore label to the padded region.
mask
=
tf
.
where
(
tf
.
equal
(
mask
,
-
1
),
ignore_label
*
tf
.
ones_like
(
mask
),
mask
)
mask
=
tf
.
squeeze
(
mask
,
axis
=
0
)
return
mask
image
,
labels
=
super
(
Parser
,
self
).
_parse_eval_data
(
data
)
image_info
=
labels
[
'image_info'
]
if
self
.
_segmentation_resize_eval_groundtruth
:
# Resizes eval masks to match input image sizes. In that case, mean IoU
# is computed on output_size not the original size of the images.
image_info
=
labels
[
'image_info'
]
image_scale
=
image_info
[
2
,
:]
offset
=
image_info
[
3
,
:]
segmentation_mask
=
preprocess_ops
.
resize_and_crop_masks
(
segmentation_mask
,
image_scale
,
self
.
_output_size
,
offset
)
else
:
segmentation_mask
=
tf
.
image
.
pad_to_bounding_box
(
segmentation_mask
,
0
,
0
,
self
.
_segmentation_groundtruth_padded_size
[
0
],
self
.
_segmentation_groundtruth_padded_size
[
1
])
segmentation_mask
-=
1
# Assign ignore label to the padded region.
segmentation_mask
=
tf
.
where
(
tf
.
equal
(
segmentation_mask
,
-
1
),
self
.
_segmentation_ignore_label
*
tf
.
ones_like
(
segmentation_mask
),
segmentation_mask
)
segmentation_mask
=
tf
.
squeeze
(
segmentation_mask
,
axis
=
0
)
segmentation_mask
=
_process_mask
(
data
[
'groundtruth_segmentation_mask'
],
self
.
_segmentation_ignore_label
,
image_info
)
segmentation_valid_mask
=
tf
.
not_equal
(
segmentation_mask
,
self
.
_segmentation_ignore_label
)
labels
[
'groundtruths'
].
update
({
'gt_segmentation_mask'
:
segmentation_mask
,
'gt_segmentation_valid_mask'
:
segmentation_valid_mask
})
if
self
.
_include_panoptic_masks
:
panoptic_category_mask
=
_process_mask
(
data
[
'groundtruth_panoptic_category_mask'
],
self
.
_panoptic_ignore_label
,
image_info
)
panoptic_instance_mask
=
_process_mask
(
data
[
'groundtruth_panoptic_instance_mask'
],
self
.
_panoptic_ignore_label
,
image_info
)
labels
[
'groundtruths'
].
update
({
'gt_panoptic_category_mask'
:
panoptic_category_mask
,
'gt_panoptic_instance_mask'
:
panoptic_instance_mask
})
return
image
,
labels
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py
View file @
ca552843
...
...
@@ -493,7 +493,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt
.
save
(
os
.
path
.
join
(
save_dir
,
'ckpt'
))
partial_ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
backbone
)
partial_ckpt
.
re
store
(
tf
.
train
.
latest_checkpoint
(
partial_ckpt
.
re
ad
(
tf
.
train
.
latest_checkpoint
(
save_dir
)).
expect_partial
().
assert_existing_objects_matched
()
partial_ckpt_mask
=
tf
.
train
.
Checkpoint
(
...
...
official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_maskrcnn.py
View file @
ca552843
...
...
@@ -78,14 +78,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
checkpoint_path
=
_get_checkpoint_path
(
self
.
task_config
.
init_checkpoint
)
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
re
store
(
checkpoint_path
)
status
.
assert_consum
ed
()
status
=
ckpt
.
re
ad
(
checkpoint_path
)
status
.
expect_partial
().
assert_existing_objects_match
ed
()
elif
init_module
==
'backbone'
:
checkpoint_path
=
_get_checkpoint_path
(
self
.
task_config
.
init_checkpoint
)
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
backbone
)
status
=
ckpt
.
re
store
(
checkpoint_path
)
status
=
ckpt
.
re
ad
(
checkpoint_path
)
status
.
expect_partial
().
assert_existing_objects_matched
()
elif
init_module
==
'segmentation_backbone'
:
...
...
@@ -93,7 +93,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self
.
task_config
.
segmentation_init_checkpoint
)
ckpt
=
tf
.
train
.
Checkpoint
(
segmentation_backbone
=
model
.
segmentation_backbone
)
status
=
ckpt
.
re
store
(
checkpoint_path
)
status
=
ckpt
.
re
ad
(
checkpoint_path
)
status
.
expect_partial
().
assert_existing_objects_matched
()
elif
init_module
==
'segmentation_decoder'
:
...
...
@@ -101,7 +101,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self
.
task_config
.
segmentation_init_checkpoint
)
ckpt
=
tf
.
train
.
Checkpoint
(
segmentation_decoder
=
model
.
segmentation_decoder
)
status
=
ckpt
.
re
store
(
checkpoint_path
)
status
=
ckpt
.
re
ad
(
checkpoint_path
)
status
.
expect_partial
().
assert_existing_objects_matched
()
else
:
...
...
@@ -122,7 +122,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
if
params
.
decoder
.
type
==
'simple_decoder'
:
decoder
=
panoptic_maskrcnn_input
.
TfExampleDecoder
(
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
,
mask_binarize_threshold
=
decoder_cfg
.
mask_binarize_threshold
)
mask_binarize_threshold
=
decoder_cfg
.
mask_binarize_threshold
,
include_panoptic_masks
=
decoder_cfg
.
include_panoptic_masks
)
else
:
raise
ValueError
(
'Unknown decoder type: {}!'
.
format
(
params
.
decoder
.
type
))
...
...
@@ -148,7 +149,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
.
segmentation_resize_eval_groundtruth
,
segmentation_groundtruth_padded_size
=
params
.
parser
.
segmentation_groundtruth_padded_size
,
segmentation_ignore_label
=
params
.
parser
.
segmentation_ignore_label
)
segmentation_ignore_label
=
params
.
parser
.
segmentation_ignore_label
,
panoptic_ignore_label
=
params
.
parser
.
panoptic_ignore_label
,
include_panoptic_masks
=
params
.
parser
.
include_panoptic_masks
)
reader
=
input_reader_factory
.
input_reader_generator
(
params
,
...
...
official/vision/beta/projects/simclr/common/registry_imports.py
View file @
ca552843
...
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
...
...
official/vision/beta/projects/simclr/configs/experiments/imagenet_simclr_multitask_tpu.yaml
0 → 100644
View file @
ca552843
runtime
:
distribution_strategy
:
tpu
mixed_precision_dtype
:
'
bfloat16'
task
:
init_checkpoint
:
'
'
model
:
backbone
:
resnet
:
model_id
:
50
type
:
resnet
projection_head
:
ft_proj_idx
:
1
num_proj_layers
:
3
proj_output_dim
:
128
backbone_trainable
:
true
heads
:
!!python/tuple
# Define heads for the PRETRAIN networks here
-
task_name
:
pretrain_imagenet
mode
:
pretrain
# # Define heads for the FINETUNE networks here
-
task_name
:
finetune_imagenet_10percent
mode
:
finetune
supervised_head
:
num_classes
:
1001
zero_init
:
true
input_size
:
[
224
,
224
,
3
]
l2_weight_decay
:
0.0
norm_activation
:
norm_epsilon
:
1.0e-05
norm_momentum
:
0.9
use_sync_bn
:
true
task_routines
:
!!python/tuple
# Define TASK CONFIG for the PRETRAIN networks here
-
task_name
:
pretrain_imagenet
task_weight
:
30.0
task_config
:
evaluation
:
one_hot
:
true
top_k
:
5
loss
:
l2_weight_decay
:
0.0
projection_norm
:
true
temperature
:
0.1
model
:
input_size
:
[
224
,
224
,
3
]
mode
:
pretrain
train_data
:
input_path
:
/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*
input_set_label_to_zero
:
true
# Set labels to zeros to double confirm that no label is used during pretrain
is_training
:
true
global_batch_size
:
4096
dtype
:
'
bfloat16'
parser
:
aug_rand_hflip
:
true
mode
:
pretrain
decoder
:
decode_label
:
true
validation_data
:
input_path
:
/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*
is_training
:
false
global_batch_size
:
2048
dtype
:
'
bfloat16'
drop_remainder
:
false
parser
:
mode
:
pretrain
decoder
:
decode_label
:
true
# Define TASK CONFIG for the FINETUNE Networks here
-
task_name
:
finetune_imagenet_10percent
task_weight
:
1.0
task_config
:
evaluation
:
one_hot
:
true
top_k
:
5
loss
:
l2_weight_decay
:
0.0
label_smoothing
:
0.0
one_hot
:
true
model
:
input_size
:
[
224
,
224
,
3
]
mode
:
finetune
supervised_head
:
num_classes
:
1001
zero_init
:
true
train_data
:
tfds_name
:
'
imagenet2012_subset/10pct'
tfds_split
:
'
train'
input_path
:
'
'
is_training
:
true
global_batch_size
:
1024
dtype
:
'
bfloat16'
parser
:
aug_rand_hflip
:
true
mode
:
finetune
decoder
:
decode_label
:
true
validation_data
:
tfds_name
:
'
imagenet2012_subset/10pct'
tfds_split
:
'
validation'
input_path
:
'
'
is_training
:
false
global_batch_size
:
2048
dtype
:
'
bfloat16'
drop_remainder
:
false
parser
:
mode
:
finetune
decoder
:
decode_label
:
true
trainer
:
trainer_type
:
interleaving
task_sampler
:
proportional
:
alpha
:
1.0
type
:
proportional
train_steps
:
32000
# 100 epochs
validation_steps
:
24
# NUM_EXAMPLES (50000) // global_batch_size
validation_interval
:
625
steps_per_loop
:
625
# NUM_EXAMPLES (1281167) // global_batch_size
summary_interval
:
625
checkpoint_interval
:
625
max_to_keep
:
3
optimizer_config
:
learning_rate
:
cosine
:
decay_steps
:
32000
initial_learning_rate
:
4.8
type
:
cosine
optimizer
:
lars
:
exclude_from_weight_decay
:
[
batch_normalization
,
bias
]
momentum
:
0.9
weight_decay_rate
:
1.0e-06
type
:
lars
warmup
:
linear
:
name
:
linear
warmup_steps
:
3200
type
:
linear
official/vision/beta/projects/simclr/configs/multitask_config.py
View file @
ca552843
...
...
@@ -29,6 +29,7 @@ from official.vision.beta.projects.simclr.modeling import simclr_model
@
dataclasses
.
dataclass
class
SimCLRMTHeadConfig
(
hyperparams
.
Config
):
"""Per-task specific configs."""
task_name
:
str
=
'task_name'
# Supervised head is required for finetune, but optional for pretrain.
supervised_head
:
simclr_configs
.
SupervisedHead
=
simclr_configs
.
SupervisedHead
(
num_classes
=
1001
)
...
...
@@ -50,6 +51,9 @@ class SimCLRMTModelConfig(hyperparams.Config):
# L2 weight decay is used in the model, not in task.
# Note that this can not be used together with lars optimizer.
l2_weight_decay
:
float
=
0.0
init_checkpoint
:
str
=
''
# backbone_projection or backbone
init_checkpoint_modules
:
str
=
'backbone_projection'
@
exp_factory
.
register_config_factory
(
'multitask_simclr'
)
...
...
@@ -57,14 +61,17 @@ def multitask_simclr() -> multitask_configs.MultiTaskExperimentConfig:
return
multitask_configs
.
MultiTaskExperimentConfig
(
task
=
multitask_configs
.
MultiTaskConfig
(
model
=
SimCLRMTModelConfig
(
heads
=
(
SimCLRMTHeadConfig
(
mode
=
simclr_model
.
PRETRAIN
),
SimCLRMTHeadConfig
(
mode
=
simclr_model
.
FINETUNE
))),
heads
=
(
SimCLRMTHeadConfig
(
task_name
=
'pretrain_simclr'
,
mode
=
simclr_model
.
PRETRAIN
),
SimCLRMTHeadConfig
(
task_name
=
'finetune_simclr'
,
mode
=
simclr_model
.
FINETUNE
))),
task_routines
=
(
multitask_configs
.
TaskRoutine
(
task_name
=
simclr_model
.
PRETRAIN
,
task_name
=
'pretrain_simclr'
,
task_config
=
simclr_configs
.
SimCLRPretrainTask
(),
task_weight
=
2.0
),
multitask_configs
.
TaskRoutine
(
task_name
=
simclr_model
.
FINETUNE
,
task_name
=
'finetune_simclr'
,
task_config
=
simclr_configs
.
SimCLRFinetuneTask
(),
task_weight
=
1.0
))),
trainer
=
multitask_configs
.
MultiTaskTrainerConfig
())
official/vision/beta/projects/simclr/configs/simclr.py
View file @
ca552843
...
...
@@ -12,27 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimCLR configurations."""
import
dataclasses
import
os
from
typing
import
List
,
Optional
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
...
...
@@ -73,6 +57,9 @@ class DataConfig(cfg.DataConfig):
# simclr specific configs
parser
:
Parser
=
Parser
()
decoder
:
Decoder
=
Decoder
()
# Useful when doing a sanity check that we absolutely use no labels while
# pretrain by setting labels to zeros (default = False, keep original labels)
input_set_label_to_zero
:
bool
=
False
@
dataclasses
.
dataclass
...
...
@@ -115,9 +102,7 @@ class SimCLRModel(hyperparams.Config):
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
())
projection_head
:
ProjectionHead
=
ProjectionHead
(
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
)
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
)
supervised_head
:
SupervisedHead
=
SupervisedHead
(
num_classes
=
1001
)
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
...
...
@@ -201,9 +186,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
projection_head
=
ProjectionHead
(
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
supervised_head
=
SupervisedHead
(
num_classes
=
1001
),
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
True
)),
...
...
@@ -233,10 +216,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
'optimizer'
:
{
'type'
:
'lars'
,
'lars'
:
{
'momentum'
:
0.9
,
'weight_decay_rate'
:
0.000001
,
'momentum'
:
0.9
,
'weight_decay_rate'
:
0.000001
,
'exclude_from_weight_decay'
:
[
'batch_normalization'
,
'bias'
]
'batch_normalization'
,
'bias'
]
}
},
'learning_rate'
:
{
...
...
@@ -278,11 +264,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
projection_head
=
ProjectionHead
(
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
supervised_head
=
SupervisedHead
(
num_classes
=
1001
,
zero_init
=
True
),
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
supervised_head
=
SupervisedHead
(
num_classes
=
1001
,
zero_init
=
True
),
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)),
loss
=
ClassificationLosses
(),
...
...
@@ -311,10 +294,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
'optimizer'
:
{
'type'
:
'lars'
,
'lars'
:
{
'momentum'
:
0.9
,
'weight_decay_rate'
:
0.0
,
'momentum'
:
0.9
,
'weight_decay_rate'
:
0.0
,
'exclude_from_weight_decay'
:
[
'batch_normalization'
,
'bias'
]
'batch_normalization'
,
'bias'
]
}
},
'learning_rate'
:
{
...
...
official/vision/beta/projects/simclr/configs/simclr_test.py
View file @
ca552843
...
...
@@ -12,23 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simclr."""
# pylint: disable=unused-import
"""Tests for SimCLR config."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
...
...
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py
View file @
ca552843
...
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing ops."""
import
functools
import
tensorflow
as
tf
...
...
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
View file @
ca552843
...
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing for SimCLR.
For pre-training:
...
...
official/vision/beta/projects/simclr/heads/simclr_head.py
View file @
ca552843
...
...
@@ -12,21 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense prediction heads."""
"""SimCLR prediction heads."""
from
typing
import
Text
,
Optional
...
...
official/vision/beta/projects/simclr/heads/simclr_head_test.py
View file @
ca552843
...
...
@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
...
...
official/vision/beta/projects/simclr/losses/contrastive_losses.py
View file @
ca552843
...
...
@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrastive loss functions."""
import
functools
...
...
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py
View file @
ca552843
...
...
@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment