ModelZoo / ResNet50_tensorflow / Commits

Unverified commit 1f8b5b27, authored Sep 03, 2021 by Simon Geisler, committed by GitHub on Sep 03, 2021

Merge branch 'master' into master

Parents: 0eeeaf98, 8fcf177e
Changes: 99 files total. This page shows 20 changed files with 1066 additions and 62 deletions (+1066, −62).
- official/vision/beta/projects/deepmac_maskrcnn/configs/experiments/deep_mask_head_rcnn_nonvoc_spinenet143_hg52.yaml (+61 −0)
- official/vision/beta/projects/deepmac_maskrcnn/configs/experiments/deep_mask_head_rcnn_voc_spinenet143_hg52.yaml (+56 −0)
- official/vision/beta/projects/movinet/README.md (+2 −2)
- official/vision/beta/projects/movinet/export_saved_model_test.py (+2 −4)
- official/vision/beta/projects/movinet/modeling/movinet.py (+24 −21)
- official/vision/beta/projects/movinet/modeling/movinet_layers.py (+2 −2)
- official/vision/beta/projects/movinet/modeling/movinet_test.py (+10 −10)
- official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py (+1 −1)
- official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_maskrcnn.py (+5 −5)
- official/vision/beta/projects/simclr/configs/multitask_config.py (+3 −0)
- official/vision/beta/projects/simclr/modeling/multitask_model.py (+36 −10)
- official/vision/beta/projects/simclr/tasks/simclr.py (+7 −7)
- official/vision/beta/projects/video_ssl/README.md (+62 −0)
- official/vision/beta/projects/video_ssl/configs/__init__.py (+4 −0)
- official/vision/beta/projects/video_ssl/configs/experiments/cvrl_linear_eval_k600.yaml (+92 −0)
- official/vision/beta/projects/video_ssl/configs/experiments/cvrl_pretrain_k600_200ep.yaml (+73 −0)
- official/vision/beta/projects/video_ssl/configs/video_ssl.py (+138 −0)
- official/vision/beta/projects/video_ssl/configs/video_ssl_test.py (+56 −0)
- official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py (+321 −0)
- official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py (+111 −0)
official/vision/beta/projects/deepmac_maskrcnn/configs/experiments/deep_mask_head_rcnn_nonvoc_spinenet143_hg52.yaml (new file, 0 → 100644). The nesting below is reconstructed from the Model Garden's Mask R-CNN config schema; the keys and values are as committed:

```yaml
# Expect to reach: box mAP: 49.3%, mask mAP: 43.4% on COCO
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  allowed_mask_class_ids: [8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31,
                           32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 46,
                           47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
                           60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81,
                           82, 84, 85, 86, 87, 88, 89, 90]
  per_category_metrics: true
  init_checkpoint: null
  train_data:
    global_batch_size: 256
    parser:
      aug_rand_hflip: true
      aug_scale_min: 0.1
      aug_scale_max: 2.0
  losses:
    l2_weight_decay: 0.00004
  model:
    mask_head:
      class_agnostic: true
      convnet_variant: 'hourglass52'
      num_filters: 64
    mask_roi_aligner:
      crop_size: 32
    use_gt_boxes_for_masks: true
    anchor:
      anchor_size: 4.0
      num_scales: 3
    min_level: 3
    max_level: 7
    input_size: [1280, 1280, 3]
    backbone:
      spinenet:
        stochastic_depth_drop_rate: 0.2
        model_id: '143'
      type: 'spinenet'
    decoder:
      type: 'identity'
    norm_activation:
      norm_epsilon: 0.001
      norm_momentum: 0.99
      use_sync_bn: true
    detection_generator:
      pre_nms_top_k: 1000
trainer:
  train_steps: 231000
  optimizer_config:
    learning_rate:
      type: 'stepwise'
      stepwise:
        boundaries: [219450, 226380]
        values: [0.32, 0.032, 0.0032]
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 2000
        warmup_learning_rate: 0.0067
```
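The stepwise schedule above decays the learning rate from 0.32 to 0.032 at step 219450 and to 0.0032 at step 226380, after a 2000-step linear warmup from 0.0067. A minimal sketch of the same shape in plain TensorFlow; the warmup composition here is an illustrative assumption, not the Model Garden's trainer code:

```python
import tensorflow as tf

# Piecewise-constant decay matching the config's boundaries/values.
stepwise = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[219450, 226380],
    values=[0.32, 0.032, 0.0032])

def lr_at(step: int) -> float:
  """Learning rate with a 2000-step linear warmup from 0.0067 to 0.32."""
  if step < 2000:
    # Linear interpolation between the warmup LR and the first stepwise value.
    return 0.0067 + (0.32 - 0.0067) * step / 2000
  return float(stepwise(step))

print(lr_at(0), lr_at(2000), lr_at(220000), lr_at(230000))
# 0.0067 0.32 0.032 0.0032
```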
official/vision/beta/projects/deepmac_maskrcnn/configs/experiments/deep_mask_head_rcnn_voc_spinenet143_hg52.yaml (new file, 0 → 100644). Identical to the non-VOC config except for the allowed mask class IDs:

```yaml
# Expect to reach: box mAP: 49.3%, mask mAP: 43.4% on COCO
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  allowed_mask_class_ids: [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44,
                           62, 63, 64, 67, 72]
  per_category_metrics: true
  init_checkpoint: null
  train_data:
    global_batch_size: 256
    parser:
      aug_rand_hflip: true
      aug_scale_min: 0.1
      aug_scale_max: 2.0
  losses:
    l2_weight_decay: 0.00004
  model:
    mask_head:
      class_agnostic: true
      convnet_variant: 'hourglass52'
      num_filters: 64
    mask_roi_aligner:
      crop_size: 32
    use_gt_boxes_for_masks: true
    anchor:
      anchor_size: 4.0
      num_scales: 3
    min_level: 3
    max_level: 7
    input_size: [1280, 1280, 3]
    backbone:
      spinenet:
        stochastic_depth_drop_rate: 0.2
        model_id: '143'
      type: 'spinenet'
    decoder:
      type: 'identity'
    norm_activation:
      norm_epsilon: 0.001
      norm_momentum: 0.99
      use_sync_bn: true
    detection_generator:
      pre_nms_top_k: 1000
trainer:
  train_steps: 231000
  optimizer_config:
    learning_rate:
      type: 'stepwise'
      stepwise:
        boundaries: [219450, 226380]
        values: [0.32, 0.032, 0.0032]
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 2000
        warmup_learning_rate: 0.0067
```
official/vision/beta/projects/movinet/README.md

````diff
@@ -338,7 +338,7 @@ with the Python API:
 ```python
 # Create the interpreter and signature runner
 interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
-signature = interpreter.get_signature_runner()
+runner = interpreter.get_signature_runner()
 
 # Extract state names and create the initial (zero) states
 def state_name(name: str) -> str:
@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
 states = init_states
 for clip in clips:
   # Input shape: [1, 1, 172, 172, 3]
-  outputs = signature(**states, image=clip)
+  outputs = runner(**states, image=clip)
   logits = outputs.pop('logits')
   states = outputs
 ```
````
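Putting the two hunks together, here is a self-contained sketch of the streaming loop this README documents, assuming a converted stream model at /tmp/movinet_a0_stream.tflite; the zero-state construction from the interpreter's input details is an illustrative assumption built on the `state_name` helper shown above, not quoted README code:

```python
import tensorflow as tf

interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
runner = interpreter.get_signature_runner()

def state_name(name: str) -> str:
  # TFLite prefixes signature inputs with 'serving_default_' and suffixes ':0'.
  return name[len('serving_default_'):-len(':0')]

# Zero-valued initial states for every signature input except the frame input.
init_states = {
    state_name(d['name']): tf.zeros(d['shape'], dtype=d['dtype'])
    for d in interpreter.get_input_details()
    if state_name(d['name']) != 'image'
}

video = tf.zeros([1, 8, 172, 172, 3])            # dummy 8-frame clip
clips = tf.split(video, video.shape[1], axis=1)  # one frame per step

states = init_states
for clip in clips:
  outputs = runner(**states, image=clip)
  logits = outputs.pop('logits')  # per-step class logits
  states = outputs                # carry the updated stream states forward
```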
official/vision/beta/projects/movinet/export_saved_model_test.py

```diff
@@ -99,8 +99,6 @@ class ExportSavedModelTest(tf.test.TestCase):
     self.assertAllClose(outputs, expected_outputs, 1e-5, 1e-5)
 
   def test_movinet_export_a0_stream_with_tflite(self):
-    self.skipTest('b/195800800')
-
     saved_model_path = self.get_temp_dir()
     FLAGS.export_path = saved_model_path
@@ -123,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
     tflite_model = converter.convert()
 
     interpreter = tf.lite.Interpreter(model_content=tflite_model)
-    signature = interpreter.get_signature_runner()
+    runner = interpreter.get_signature_runner('serving_default')
 
     def state_name(name: str) -> str:
       return name[len('serving_default_'):-len(':0')]
@@ -139,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
     states = init_states
     for clip in clips:
-      outputs = signature(**states, image=clip)
+      outputs = runner(**states, image=clip)
       logits = outputs.pop('logits')
       states = outputs
```
official/vision/beta/projects/movinet/modeling/movinet.py

```diff
@@ -17,10 +17,10 @@
 Reference: https://arxiv.org/pdf/2103.11511.pdf
 """
-import dataclasses
 import math
 from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
+import dataclasses
 import tensorflow as tf
 
 from official.modeling import hyperparams
@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
     stochastic_depth_idx = 1
     for block_idx, block in enumerate(self._block_specs):
       if isinstance(block, StemSpec):
-        x, states = movinet_layers.Stem(
+        layer_obj = movinet_layers.Stem(
             block.filters,
             block.kernel_size,
             block.strides,
@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
             batch_norm_layer=self._norm,
             batch_norm_momentum=self._norm_momentum,
             batch_norm_epsilon=self._norm_epsilon,
-            state_prefix='state/stem',
-            name='stem')(
-                x, states=states)
+            state_prefix='state_stem',
+            name='stem')
+        x, states = layer_obj(x, states=states)
         endpoints['stem'] = x
       elif isinstance(block, MovinetBlockSpec):
         if not (len(block.expand_filters) == len(block.kernel_sizes) ==
@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
               self._stochastic_depth_drop_rate * stochastic_depth_idx /
               num_layers)
           expand_filters, kernel_size, strides = layer
-          name = f'b{block_idx-1}/l{layer_idx}'
-          x, states = movinet_layers.MovinetBlock(
+          name = f'block{block_idx-1}_layer{layer_idx}'
+          layer_obj = movinet_layers.MovinetBlock(
               block.base_filters,
               expand_filters,
              kernel_size=kernel_size,
@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
               batch_norm_layer=self._norm,
               batch_norm_momentum=self._norm_momentum,
               batch_norm_epsilon=self._norm_epsilon,
-              state_prefix=f'state/{name}',
-              name=name)(
-                  x, states=states)
+              state_prefix=f'state_{name}',
+              name=name)
+          x, states = layer_obj(x, states=states)
           endpoints[name] = x
           stochastic_depth_idx += 1
       elif isinstance(block, HeadSpec):
-        x, states = movinet_layers.Head(
+        layer_obj = movinet_layers.Head(
             project_filters=block.project_filters,
             conv_type=self._conv_type,
             activation=self._activation,
@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
             batch_norm_layer=self._norm,
             batch_norm_momentum=self._norm_momentum,
             batch_norm_epsilon=self._norm_epsilon,
-            state_prefix='state/head',
-            name='head')(
-                x, states=states)
+            state_prefix='state_head',
+            name='head')
+        x, states = layer_obj(x, states=states)
         endpoints['head'] = x
       else:
         raise ValueError('Unknown block type {}'.format(block))
@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
     for block_idx, block in enumerate(block_specs):
       if isinstance(block, StemSpec):
         if block.kernel_size[0] > 1:
-          states['state/stem/stream_buffer'] = (
+          states['state_stem_stream_buffer'] = (
               input_shape[0],
               input_shape[1],
               divide_resolution(input_shape[2], num_downsamples),
@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
               self._conv_type in ['2plus1d', '3d_2plus1d']):
             num_downsamples += 1
 
+          prefix = f'state_block{block_idx}_layer{layer_idx}'
+
           if kernel_size[0] > 1:
-            states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = (
+            states[f'{prefix}_stream_buffer'] = (
                 input_shape[0],
                 kernel_size[0] - 1,
                 divide_resolution(input_shape[2], num_downsamples),
@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
                 expand_filters,
             )
 
-          states[f'state/b{block_idx}/l{layer_idx}/pool_buffer'] = (
+          states[f'{prefix}_pool_buffer'] = (
               input_shape[0], 1, 1, 1, expand_filters,
           )
-          states[f'state/b{block_idx}/l{layer_idx}/pool_frame_count'] = (1,)
+          states[f'{prefix}_pool_frame_count'] = (1,)
 
           if use_positional_encoding:
-            name = f'state/b{block_idx}/l{layer_idx}/pos_enc_frame_count'
+            name = f'{prefix}_pos_enc_frame_count'
             states[name] = (1,)
 
           if strides[1] != strides[2]:
@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
               self._conv_type not in ['2plus1d', '3d_2plus1d']):
             num_downsamples += 1
       elif isinstance(block, HeadSpec):
-        states['state/head/pool_buffer'] = (
+        states['state_head_pool_buffer'] = (
            input_shape[0], 1, 1, 1, block.project_filters,
        )
-        states['state/head/pool_frame_count'] = (1,)
+        states['state_head_pool_frame_count'] = (1,)
 
     return states
```
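The recurring rename in this file, 'state/stem' to 'state_stem' and f'state/b{block_idx}/l{layer_idx}/...' to f'state_block{block_idx}_layer{layer_idx}_...', flattens every state-dict key to a plain identifier, and each layer is now built and called in separate statements. A plausible reading, offered as an assumption rather than stated intent: these keys become named signature inputs and outputs when the streaming model is exported (see the README and export test above), and '/' does not survive as a clean signature key. A minimal illustration of that constraint; the regex approximates, rather than quotes, TensorFlow's actual naming rule:

```python
import re

def looks_like_valid_signature_key(name: str) -> bool:
  # Signature tensor names behave like identifiers: letters, digits,
  # and underscores pass through conversion; '/' is a scope separator.
  return re.fullmatch(r'[A-Za-z0-9_]+', name) is not None

print(looks_like_valid_signature_key('state/b1/l0/stream_buffer'))       # False
print(looks_like_valid_signature_key('state_block1_layer0_stream_buffer'))  # True
```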
official/vision/beta/projects/movinet/modeling/movinet_layers.py

```diff
@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
     state_prefix = state_prefix if state_prefix is not None else ''
     self._state_prefix = state_prefix
-    self._state_name = f'{state_prefix}/stream_buffer'
+    self._state_name = f'{state_prefix}_stream_buffer'
     self._buffer_size = buffer_size
 
   def get_config(self):
@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
       inputs: the input tensor.
       states: a dict of states such that, if any of the keys match for this
         layer, will overwrite the contents of the buffer(s).
-        Expected keys include `state_prefix + '/stream_buffer'`.
+        Expected keys include `state_prefix + '_stream_buffer'`.
 
     Returns:
       the output tensor and states
```
official/vision/beta/projects/movinet/modeling/movinet_test.py

```diff
@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
     endpoints, states = network(inputs)
 
     self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
-    self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
-    self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
-    self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
-    self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
-    self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
+    self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
+    self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
+    self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
+    self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
+    self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
     self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
 
     self.assertNotEmpty(states)
@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
     endpoints, new_states = backbone({**init_states, 'image': inputs})
 
     self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
-    self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
-    self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
-    self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
-    self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
-    self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
+    self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
+    self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
+    self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
+    self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
+    self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
     self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
 
     self.assertNotEmpty(init_states)
```
official/vision/beta/projects/panoptic_maskrcnn/modeling/panoptic_maskrcnn_model_test.py

```diff
@@ -460,7 +460,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
     ckpt.save(os.path.join(save_dir, 'ckpt'))
 
     partial_ckpt = tf.train.Checkpoint(backbone=backbone)
-    partial_ckpt.restore(tf.train.latest_checkpoint(
+    partial_ckpt.read(tf.train.latest_checkpoint(
         save_dir)).expect_partial().assert_existing_objects_matched()
 
     partial_ckpt_mask = tf.train.Checkpoint(
```
official/vision/beta/projects/panoptic_maskrcnn/tasks/panoptic_maskrcnn.py

```diff
@@ -77,14 +77,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
       checkpoint_path = _get_checkpoint_path(
           self.task_config.init_checkpoint)
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(checkpoint_path)
-      status.assert_consumed()
+      status = ckpt.read(checkpoint_path)
+      status.expect_partial().assert_existing_objects_matched()
     elif init_module == 'backbone':
       checkpoint_path = _get_checkpoint_path(
           self.task_config.init_checkpoint)
       ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(checkpoint_path)
+      status = ckpt.read(checkpoint_path)
       status.expect_partial().assert_existing_objects_matched()
     elif init_module == 'segmentation_backbone':
@@ -92,7 +92,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
           self.task_config.segmentation_init_checkpoint)
       ckpt = tf.train.Checkpoint(
           segmentation_backbone=model.segmentation_backbone)
-      status = ckpt.restore(checkpoint_path)
+      status = ckpt.read(checkpoint_path)
       status.expect_partial().assert_existing_objects_matched()
     elif init_module == 'segmentation_decoder':
@@ -100,7 +100,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
           self.task_config.segmentation_init_checkpoint)
       ckpt = tf.train.Checkpoint(
           segmentation_decoder=model.segmentation_decoder)
-      status = ckpt.restore(checkpoint_path)
+      status = ckpt.read(checkpoint_path)
       status.expect_partial().assert_existing_objects_matched()
     else:
```
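These hunks switch `Checkpoint.restore` to `Checkpoint.read` and relax the assertion. The practical difference, as I read it: `restore` resolves the full saved object graph (including trainer bookkeeping such as the save counter and optimizer slots), and `assert_consumed` fails if anything in the checkpoint goes unused, whereas `read` loads by the checkpoint's own structure and `expect_partial().assert_existing_objects_matched()` only requires that every variable present in the model found a value. A minimal sketch of the new pattern; the Dense layer and path are placeholders standing in for `model.backbone`:

```python
import tensorflow as tf

# A stand-in trackable; in the task above this is e.g. model.backbone.
backbone = tf.keras.layers.Dense(4)
backbone.build([None, 8])

ckpt = tf.train.Checkpoint(backbone=backbone)
path = ckpt.write('/tmp/demo_ckpt')  # write() pairs with read()

status = tf.train.Checkpoint(backbone=backbone).read(path)
# Passes as long as every existing object was matched, even if the
# checkpoint also contained entries this model does not use.
status.expect_partial().assert_existing_objects_matched()
```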
official/vision/beta/projects/simclr/configs/multitask_config.py

```diff
@@ -51,6 +51,9 @@ class SimCLRMTModelConfig(hyperparams.Config):
   # L2 weight decay is used in the model, not in task.
   # Note that this can not be used together with lars optimizer.
   l2_weight_decay: float = 0.0
+  init_checkpoint: str = ''
+  # backbone_projection or backbone
+  init_checkpoint_modules: str = 'backbone_projection'
 
 
 @exp_factory.register_config_factory('multitask_simclr')
```
official/vision/beta/projects/simclr/modeling/multitask_model.py

```diff
@@ -14,6 +14,7 @@
 """Multi-task image multi-taskSimCLR model definition."""
 from typing import Dict, Text
 
+from absl import logging
 import tensorflow as tf
@@ -52,15 +53,10 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
         norm_activation_config=config.norm_activation,
         l2_regularizer=self._l2_regularizer)
 
-    super().__init__(**kwargs)
-
-  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
-    tasks = {}
-
     # Build the shared projection head
     norm_activation_config = self._config.norm_activation
     projection_head_config = self._config.projection_head
-    projection_head = simclr_head.ProjectionHead(
+    self._projection_head = simclr_head.ProjectionHead(
         proj_output_dim=projection_head_config.proj_output_dim,
         num_proj_layers=projection_head_config.num_proj_layers,
         ft_proj_idx=projection_head_config.ft_proj_idx,
@@ -69,6 +65,11 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
         norm_momentum=norm_activation_config.norm_momentum,
         norm_epsilon=norm_activation_config.norm_epsilon)
 
+    super().__init__(**kwargs)
+
+  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
+    tasks = {}
+
     for model_config in self._config.heads:
       # Build supervised head
       supervised_head_config = model_config.supervised_head
@@ -87,13 +88,38 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
       tasks[model_config.task_name] = simclr_model.SimCLRModel(
           input_specs=self._input_specs,
           backbone=self._backbone,
-          projection_head=projection_head,
+          projection_head=self._projection_head,
           supervised_head=supervised_head,
           mode=model_config.mode,
           backbone_trainable=self._config.backbone_trainable)
 
     return tasks
 
-  # TODO(huythong): Implement initialize function to load the pretrained
-  # checkpoint of backbone.
-  # def initialize(self):
+  def initialize(self):
+    """Loads the multi-task SimCLR model with a pretrained checkpoint."""
+    ckpt_dir_or_file = self._config.init_checkpoint
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if not ckpt_dir_or_file:
+      return
+
+    logging.info('Loading pretrained %s',
+                 self._config.init_checkpoint_modules)
+    if self._config.init_checkpoint_modules == 'backbone':
+      pretrained_items = dict(backbone=self._backbone)
+    elif self._config.init_checkpoint_modules == 'backbone_projection':
+      pretrained_items = dict(
+          backbone=self._backbone, projection_head=self._projection_head)
+    else:
+      assert ("Only 'backbone_projection' or 'backbone' can be used to "
+              'initialize the model.')
+
+    ckpt = tf.train.Checkpoint(**pretrained_items)
+    status = ckpt.read(ckpt_dir_or_file)
+    status.expect_partial().assert_existing_objects_matched()
+    logging.info('Finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
+
+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    return dict(
+        backbone=self._backbone, projection_head=self._projection_head)
```
official/vision/beta/projects/simclr/tasks/simclr.py

```diff
@@ -150,11 +150,11 @@ class SimCLRPretrainTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone':
       ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
       assert "Only 'all' or 'backbone' can be used to initialize the model."
@@ -455,16 +455,16 @@ class SimCLRFinetuneTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone_projection':
       ckpt = tf.train.Checkpoint(
           backbone=model.backbone, projection_head=model.projection_head)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone':
       ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
       assert "Only 'all' or 'backbone' can be used to initialize the model."
```
official/vision/beta/projects/video_ssl/README.md (new file, 0 → 100644)

# Spatiotemporal Contrastive Video Representation Learning

[arXiv:2008.03800](https://arxiv.org/abs/2008.03800)

This repository is the official TF2 implementation of
[Spatiotemporal Contrastive Video Representation Learning](https://arxiv.org/abs/2008.03800).

<p align="left">
  <img src="https://storage.googleapis.com/tf_model_garden/vision/cvrl/artifacts/cvrl_overview.png" height=350>
</p>

## Description

We present a self-supervised Contrastive Video Representation Learning (CVRL)
method to learn spatiotemporal visual representations from unlabeled videos. Our
representations are learned using a contrastive loss, where two augmented clips
from the same short video are pulled together in the embedding space, while
clips from different videos are pushed away. CVRL significantly closes the gap
between unsupervised and supervised video representation learning.

We release the code and pre-trained models. More pre-trained model checkpoints
and detailed instructions for the code will be added.
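The contrastive objective described above is an InfoNCE/NT-Xent loss applied to two augmented clips of the same video. A minimal sketch under that assumption; this is illustrative and not the project's actual loss implementation:

```python
import tensorflow as tf

def ntxent_loss(z1: tf.Tensor, z2: tf.Tensor, temperature: float = 0.1):
  """NT-Xent loss for two batches of clip embeddings, each of shape [B, D]."""
  z1 = tf.math.l2_normalize(z1, axis=-1)
  z2 = tf.math.l2_normalize(z2, axis=-1)
  z = tf.concat([z1, z2], axis=0)                         # [2B, D]
  sim = tf.matmul(z, z, transpose_b=True) / temperature   # [2B, 2B]
  batch = tf.shape(z1)[0]
  # Mask self-similarity so a clip cannot match itself.
  sim = sim - tf.eye(2 * batch) * 1e9
  # The positive for clip i is its augmented pair at index i + B (mod 2B).
  labels = tf.concat([tf.range(batch) + batch, tf.range(batch)], axis=0)
  loss = tf.keras.losses.sparse_categorical_crossentropy(
      labels, sim, from_logits=True)
  return tf.reduce_mean(loss)

# Example: 4 videos, 128-d projections of two augmented clips each.
loss = ntxent_loss(tf.random.normal([4, 128]), tf.random.normal([4, 128]))
```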
## Experimental Results

### Kinetics-600 top-1 linear classification accuracy

<p align="left">
  <img src="https://storage.googleapis.com/tf_model_garden/vision/cvrl/artifacts/cvrl_results.png" height=350>
</p>

## Pre-trained Model Checkpoints

We provide model checkpoints pre-trained on unlabeled RGB videos from
Kinetics-400 and Kinetics-600. All models are trained from scratch with random
initialization.

We also provide the "ImageNet inflated" baseline model checkpoint used in the
paper. The model has the same architecture as 3D-ResNet-50 (R3D-50), with model
weights inflated from a 2D ResNet-50 pre-trained on ImageNet.

| Model | Parameters | Dataset | Epochs | K400 Linear Eval. | K600 Linear Eval. | Checkpoint |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| R3D-50 (1x) | 31.7M | ImageNet | - | 53.5% | 54.7% | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/imagenet.tar.gz) |
| R3D-50 (1x) | 31.7M | Kinetics-400 | 200 | 63.8% | - | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k400_200ep.tar.gz) |
| R3D-50 (1x) | 31.7M | Kinetics-400 | 800 | 66.1% | - | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k400_800ep.tar.gz) |
| R3D-50 (1x) | 31.7M | Kinetics-600 | 800 | 68.5% | 70.4% | [ckpt (127 MB)](https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k600_800ep.tar.gz) |

## Citation

```
@inproceedings{qian2021spatiotemporal,
  title={Spatiotemporal contrastive video representation learning},
  author={Qian, Rui and Meng, Tianjian and Gong, Boqing and Yang, Ming-Hsuan and Wang, Huisheng and Belongie, Serge and Cui, Yin},
  booktitle={CVPR},
  year={2021}
}
```
official/staging/__init__.py → official/vision/beta/projects/video_ssl/configs/__init__.py (renamed)

```diff
@@ -12,3 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# Lint as: python3
+"""Configs package definition."""
+from official.vision.beta.projects.video_ssl.configs import video_ssl
```
official/vision/beta/projects/video_ssl/configs/experiments/cvrl_linear_eval_k600.yaml (new file, 0 → 100644). Indentation reconstructed from the video-classification config schema:

```yaml
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  # Put the pretrained checkpoint here for linear evaluation
  init_checkpoint: 'r3d_1x_k600_800ep_backbone-1'
  init_checkpoint_modules: 'backbone'
  model:
    dropout_rate: 1.0
    norm_activation:
      use_sync_bn: false
    backbone:
      resnet_3d:
        block_specs: !!python/tuple
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        model_id: 50
        stem_conv_temporal_kernel_size: 5
        stem_conv_temporal_stride: 2
        stem_pool_temporal_stride: 1
  train_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 32
    - 224
    - 224
    - 3
    temporal_stride: 2
    global_batch_size: 1024
    dtype: 'bfloat16'
    shuffle_buffer_size: 1024
    aug_max_area_ratio: 1.0
    aug_max_aspect_ratio: 2.0
    aug_min_area_ratio: 0.3
    aug_min_aspect_ratio: 0.5
  validation_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 32
    - 256
    - 256
    - 3
    temporal_stride: 2
    num_test_clips: 10
    num_test_crops: 3
    global_batch_size: 64
    dtype: 'bfloat16'
    drop_remainder: false
  losses:
    l2_weight_decay: 0.0
trainer:
  optimizer_config:
    learning_rate:
      cosine:
        initial_learning_rate: 32.0
        decay_steps: 35744
    optimizer:
      sgd:
        nesterov: false
    warmup:
      linear:
        warmup_steps: 1787
  train_steps: 35744
  steps_per_loop: 100
  summary_interval: 100
  validation_interval: 100
```
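One consequence of `num_test_clips: 10` and `num_test_crops: 3` worth spelling out: evaluation builds 10 × 3 = 30 views per video, and the parser stacks them along the frame axis before batching, so a 32-frame clip yields 32 × 30 = 960 stacked frames. This is exactly the `(960, 256, 256, 3)` shape asserted in video_ssl_input_test.py below:

```python
num_test_clips, num_test_crops, num_frames = 10, 3, 32
views = num_test_clips * num_test_crops  # 30 views per video
frames = num_frames * views              # 960 stacked frames per video
assert (views, frames) == (30, 960)
```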
official/vision/beta/projects/video_ssl/configs/experiments/cvrl_pretrain_k600_200ep.yaml (new file, 0 → 100644). Indentation reconstructed as above:

```yaml
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    dropout_rate: 1.0
    norm_activation:
      use_sync_bn: true
    hidden_norm_activation:
      use_sync_bn: true
    backbone:
      resnet_3d:
        block_specs: !!python/tuple
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 1
          - 1
          - 1
          - 1
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        - temporal_kernel_sizes: !!python/tuple
          - 3
          - 3
          - 3
          temporal_strides: 1
          use_self_gating: false
        model_id: 50
        stem_conv_temporal_kernel_size: 5
        stem_conv_temporal_stride: 2
        stem_pool_temporal_stride: 1
  train_data:
    name: kinetics600
    feature_shape: !!python/tuple
    - 16
    - 224
    - 224
    - 3
    temporal_stride: 2
    global_batch_size: 1024
    dtype: 'bfloat16'
    shuffle_buffer_size: 1024
  losses:
    l2_weight_decay: 0.000001
trainer:
  optimizer_config:
    learning_rate:
      cosine:
        initial_learning_rate: 0.32
        decay_steps: 71488
    optimizer:
      sgd:
        nesterov: false
    warmup:
      linear:
        warmup_steps: 1787
  train_steps: 71488
  steps_per_loop: 100
  summary_interval: 100
```
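The pretraining schedule warms up linearly for 1787 steps and then decays 0.32 toward zero over 71488 steps on a cosine curve. A sketch of that shape; the exact composition of warmup and decay in the Model Garden trainer may differ, so treat this as an approximation:

```python
import math

def cvrl_lr(step: int,
            base_lr: float = 0.32,
            warmup_steps: int = 1787,
            decay_steps: int = 71488) -> float:
  """Linear warmup into a cosine decay, as configured above."""
  if step < warmup_steps:
    return base_lr * step / warmup_steps
  progress = (step - warmup_steps) / (decay_steps - warmup_steps)
  return 0.5 * base_lr * (1 + math.cos(math.pi * progress))

print(cvrl_lr(0), cvrl_lr(1787), cvrl_lr(71488))  # 0.0, 0.32, ~0.0
```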
official/vision/beta/projects/video_ssl/configs/video_ssl.py (new file, 0 → 100644). The committed file repeated one copy-pasted docstring ("Pretrain SSL Video classification on Kinectics 400 with resnet.") across all four factories; the docstrings below are corrected to match each function:

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Video classification configuration definition."""
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.configs import common
from official.vision.beta.configs import video_classification

Losses = video_classification.Losses
VideoClassificationModel = video_classification.VideoClassificationModel
VideoClassificationTask = video_classification.VideoClassificationTask


@dataclasses.dataclass
class VideoSSLPretrainTask(VideoClassificationTask):
  pass


@dataclasses.dataclass
class VideoSSLEvalTask(VideoClassificationTask):
  pass


@dataclasses.dataclass
class DataConfig(video_classification.DataConfig):
  """The base configuration for building datasets."""
  is_ssl: bool = False


@dataclasses.dataclass
class VideoSSLModel(VideoClassificationModel):
  """The model config."""
  normalize_feature: bool = False
  hidden_dim: int = 2048
  hidden_layer_num: int = 3
  projection_dim: int = 128
  hidden_norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)


@dataclasses.dataclass
class SSLLosses(Losses):
  normalize_hidden: bool = True
  temperature: float = 0.1


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics400')
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 400 with ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.losses = SSLLosses(exp.task.losses)
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics400')
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 400 with ResNet."""
  exp = video_classification.video_classification_kinetics400()
  exp.task = VideoSSLEvalTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.model.normalize_feature = True
  exp.task.model.hidden_layer_num = 0
  exp.task.model.projection_dim = 400
  return exp


@exp_factory.register_config_factory('video_ssl_pretrain_kinetics600')
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Pretrains SSL video classification on Kinetics 600 with ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task = VideoSSLPretrainTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=True, **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (16, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.losses = SSLLosses(exp.task.losses)
  return exp


@exp_factory.register_config_factory('video_ssl_linear_eval_kinetics600')
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Linear evaluation of SSL video classification on Kinetics 600 with ResNet."""
  exp = video_classification.video_classification_kinetics600()
  exp.task = VideoSSLEvalTask(**exp.task.as_dict())
  exp.task.train_data = DataConfig(is_ssl=False,
                                   **exp.task.train_data.as_dict())
  exp.task.train_data.feature_shape = (32, 224, 224, 3)
  exp.task.train_data.temporal_stride = 2
  exp.task.validation_data = DataConfig(is_ssl=False,
                                        **exp.task.validation_data.as_dict())
  exp.task.validation_data.feature_shape = (32, 256, 256, 3)
  exp.task.validation_data.temporal_stride = 2
  exp.task.validation_data.min_image_size = 256
  exp.task.validation_data.num_test_clips = 10
  exp.task.validation_data.num_test_crops = 3
  exp.task.model = VideoSSLModel(exp.task.model)
  exp.task.model.model_type = 'video_ssl_model'
  exp.task.model.normalize_feature = True
  exp.task.model.hidden_layer_num = 0
  exp.task.model.projection_dim = 600
  return exp
```
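These factories register experiment names with the Model Garden's `exp_factory`; fetching one returns a fully populated `ExperimentConfig` that can then be overridden from YAML files such as the two above. A minimal usage sketch, grounded in the names defined in this file:

```python
from official.core import exp_factory
# Importing the module runs the @register_config_factory decorators.
from official.vision.beta.projects.video_ssl.configs import video_ssl  # pylint: disable=unused-import

config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
print(type(config.task).__name__)            # VideoSSLPretrainTask
print(config.task.train_data.is_ssl)         # True
print(config.task.train_data.feature_shape)  # (16, 224, 224, 3)
```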
official/vision/beta/projects/video_ssl/configs/video_ssl_test.py (new file, 0 → 100644)

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision import beta
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg


class VideoClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('video_ssl_pretrain_kinetics400',),
                            ('video_ssl_pretrain_kinetics600',))
  def test_video_ssl_pretrain_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoSSLPretrainTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
    self.assertIsInstance(config.task.losses, exp_cfg.SSLLosses)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()

  @parameterized.parameters(('video_ssl_linear_eval_kinetics400',),
                            ('video_ssl_linear_eval_kinetics600',))
  def test_video_ssl_linear_eval_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.VideoSSLEvalTask)
    self.assertIsInstance(config.task.model, exp_cfg.VideoSSLModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
```
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input.py (new file, 0 → 100644). Docstring typos in the committed file ("aggreagated", "dimenstion", "examle") are corrected below:

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Parser for video and label datasets."""

from typing import Dict, Optional, Tuple

from absl import logging
import tensorflow as tf

from official.vision.beta.dataloaders import video_input
from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops

IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'

Decoder = video_input.Decoder


def _process_image(image: tf.Tensor,
                   is_training: bool = True,
                   is_ssl: bool = False,
                   num_frames: int = 32,
                   stride: int = 1,
                   num_test_clips: int = 1,
                   min_resize: int = 256,
                   crop_size: int = 224,
                   num_crops: int = 1,
                   zero_centering_image: bool = False,
                   seed: Optional[int] = None) -> tf.Tensor:
  """Processes a serialized image tensor.

  Args:
    image: Input Tensor of shape [timesteps] and type tf.string of serialized
      frames.
    is_training: Whether or not in training mode. If True, random sample, crop
      and left right flip is used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    stride: Temporal stride to sample frames.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    min_resize: Frames are resized so that min(height, width) is min_resize.
    crop_size: Final size of the frame after cropping the resized frames. Both
      height and width are the same.
    num_crops: Number of crops to perform on the resized frames.
    zero_centering_image: If True, frames are normalized to values in [-1, 1].
      If False, values in [0, 1].
    seed: A deterministic seed to use when sampling.

  Returns:
    Processed frames. Tensor of shape
      [num_frames * num_test_clips, crop_size, crop_size, 3].
  """
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)

  # Temporal sampler.
  if is_training:
    # Sampler for training.
    if is_ssl:
      # Sample two clips from linear decreasing distribution.
      image = video_ssl_preprocess_ops.sample_ssl_sequence(
          image, num_frames, True, stride)
    else:
      # Sample random clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, True,
                                                stride)
  else:
    # Sampler for evaluation.
    if num_test_clips > 1:
      # Sample linspace clips.
      image = preprocess_ops_3d.sample_linspace_sequence(
          image, num_test_clips, num_frames, stride)
    else:
      # Sample middle clip.
      image = preprocess_ops_3d.sample_sequence(image, num_frames, False,
                                                stride)

  # Decode JPEG string to tf.uint8.
  image = preprocess_ops_3d.decode_jpeg(image, 3)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    if is_ssl:
      image_1, image_2 = tf.split(image, num_or_size_splits=2, axis=0)
      image_1 = preprocess_ops_3d.random_crop_resize(
          image_1, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_1 = preprocess_ops_3d.random_flip_left_right(image_1, seed)
      image_2 = preprocess_ops_3d.random_crop_resize(
          image_2, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image_2 = preprocess_ops_3d.random_flip_left_right(image_2, seed)
    else:
      image = preprocess_ops_3d.random_crop_resize(
          image, crop_size, crop_size, num_frames, 3, (0.5, 2), (0.3, 1))
      image = preprocess_ops_3d.random_flip_left_right(image, seed)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Three-crop of the frames.
    image = preprocess_ops_3d.crop_image(image, crop_size, crop_size, False,
                                         num_crops)

  # Cast the frames in float32, normalizing according to zero_centering_image.
  if is_training and is_ssl:
    image_1 = preprocess_ops_3d.normalize_image(image_1, zero_centering_image)
    image_2 = preprocess_ops_3d.normalize_image(image_2, zero_centering_image)
  else:
    image = preprocess_ops_3d.normalize_image(image, zero_centering_image)

  # Self-supervised pre-training augmentations.
  if is_training and is_ssl:
    # Temporally consistent color jittering.
    image_1 = video_ssl_preprocess_ops.random_color_jitter_3d(image_1)
    image_2 = video_ssl_preprocess_ops.random_color_jitter_3d(image_2)
    # Temporally consistent gaussian blurring.
    image_1 = video_ssl_preprocess_ops.random_blur(
        image_1, crop_size, crop_size, 1.0)
    image_2 = video_ssl_preprocess_ops.random_blur(
        image_2, crop_size, crop_size, 0.1)
    image_2 = video_ssl_preprocess_ops.random_solarization(image_2)
    image = tf.concat([image_1, image_2], axis=0)
    image = tf.clip_by_value(image, 0., 1.)

  return image


def _postprocess_image(image: tf.Tensor,
                       is_training: bool = True,
                       is_ssl: bool = False,
                       num_frames: int = 32,
                       num_test_clips: int = 1,
                       num_test_crops: int = 1) -> tf.Tensor:
  """Processes a batched Tensor of frames.

  The same parameters used in process should be used here.

  Args:
    image: Input Tensor of shape [batch, timesteps, height, width, 3].
    is_training: Whether or not in training mode. If True, random sample, crop
      and left right flip is used.
    is_ssl: Whether or not in self-supervised pre-training mode.
    num_frames: Number of frames per subclip.
    num_test_clips: Number of test clips (1 by default). If more than 1, this
      will sample multiple linearly spaced clips within each video at test
      time. If 1, then a single clip in the middle of the video is sampled.
      The clips are aggregated in the batch dimension.
    num_test_crops: Number of test crops (1 by default). If more than 1, there
      are multiple crops for each clip at test time. If 1, there is a single
      central crop. The crops are aggregated in the batch dimension.

  Returns:
    Processed frames. Tensor of shape
      [batch * num_test_clips * num_test_crops, num_frames, height, width, 3].
  """
  if is_ssl and is_training:
    # In this case, two clips of self-supervised pre-training are merged
    # together in batch dimension which will be 2 * batch.
    image = tf.concat(tf.split(image, num_or_size_splits=2, axis=1), axis=0)

  num_views = num_test_clips * num_test_crops
  if num_views > 1 and not is_training:
    # In this case, multiple views are merged together in batch dimension
    # which will be batch * num_views.
    image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())

  return image


def _process_label(label: tf.Tensor,
                   one_hot_label: bool = True,
                   num_classes: Optional[int] = None) -> tf.Tensor:
  """Processes label Tensor."""
  # Validate parameters.
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  # Cast to tf.int32.
  label = tf.cast(label, dtype=tf.int32)

  if one_hot_label:
    # Replace label index by one hot representation.
    label = tf.one_hot(label, num_classes)
    if len(label.shape.as_list()) > 1:
      label = tf.reduce_sum(label, axis=0)
    if num_classes == 1:
      # The trick for single label.
      label = 1 - label

  return label


class Parser(video_input.Parser):
  """Parses a video and label dataset."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    super(Parser, self).__init__(input_params, image_key, label_key)
    self._is_ssl = input_params.is_ssl

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for training."""
    # Process image and label.
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=True,
        is_ssl=self._is_ssl,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for evaluation."""
    image = decoded_tensors[self._image_key]
    image = _process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=self._num_crops)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = _process_label(label, self._one_hot_label, self._num_classes)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(
          audio, 20, random=False, stride=1)
      audio = tf.ensure_shape(audio, [20, 2048])
      features['audio'] = audio

    return features, label

  def parse_fn(self, is_training):
    """Returns a parse fn that reads and parses raw tensors from the decoder.

    Args:
      is_training: a `bool` to indicate whether it is in training mode.

    Returns:
      parse: a `callable` that takes the serialized example and generates the
        images, labels tuple where labels is a dict of Tensors that contains
        labels.
    """
    def parse(decoded_tensors):
      """Parses the serialized example data."""
      if is_training:
        return self._parse_train_data(decoded_tensors)
      else:
        return self._parse_eval_data(decoded_tensors)

    return parse


class PostBatchProcessor(object):
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self._is_training = input_params.is_training
    self._is_ssl = input_params.is_ssl
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips
    self._num_test_crops = input_params.num_test_crops

  def __call__(self, features: Dict[str, tf.Tensor],
               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses a single tf.Example into image and label tensors."""
    for key in ['image', 'audio']:
      if key in features:
        features[key] = _postprocess_image(
            image=features[key],
            is_training=self._is_training,
            is_ssl=self._is_ssl,
            num_frames=self._num_frames,
            num_test_clips=self._num_test_clips,
            num_test_crops=self._num_test_crops)

    return features, label
```
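A hedged sketch of how `Parser` and `PostBatchProcessor` compose in a tf.data pipeline; the Model Garden's real input pipeline wires these through its reader abstraction, and the TFRecord path below is a placeholder:

```python
import tensorflow as tf
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data
decoder = video_ssl_input.Decoder()
parse = video_ssl_input.Parser(params).parse_fn(is_training=True)
postprocess = video_ssl_input.PostBatchProcessor(params)

# Serialized tf.train.SequenceExamples in; (features, label) batches out.
dataset = (
    tf.data.TFRecordDataset('/path/to/kinetics600.tfrecord')  # placeholder
    .map(decoder.decode)   # raw bytes -> decoded tensors
    .map(parse)            # per-example sampling and augmentation
    .batch(8)
    .map(postprocess))     # merge SSL clip pairs / test views into the batch
```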
official/vision/beta/projects/video_ssl/dataloaders/video_ssl_input_test.py (new file, 0 → 100644). A stray debug print in the committed eval test is dropped below:

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
import io

# Import libraries
import numpy as np
from PIL import Image
import tensorflow as tf

from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input

AUDIO_KEY = 'features/audio'


def fake_seq_example():
  # Create fake data.
  random_image = np.random.randint(0, 256, size=(263, 320, 3), dtype=np.uint8)
  random_image = Image.fromarray(random_image)
  label = 42

  with io.BytesIO() as buffer:
    random_image.save(buffer, format='JPEG')
    raw_image_bytes = buffer.getvalue()

  seq_example = tf.train.SequenceExample()
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.feature_lists.feature_list.get_or_create(
      video_ssl_input.IMAGE_KEY).feature.add().bytes_list.value[:] = [
          raw_image_bytes
      ]
  seq_example.context.feature[
      video_ssl_input.LABEL_KEY].int64_list.value[:] = [label]

  random_audio = np.random.normal(size=(10, 256)).tolist()
  for s in random_audio:
    seq_example.feature_lists.feature_list.get_or_create(
        AUDIO_KEY).feature.add().float_list.value[:] = s

  return seq_example, label


class VideoAndLabelParserTest(tf.test.TestCase):

  def test_video_ssl_input_pretrain(self):
    params = exp_cfg.video_ssl_pretrain_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, _ = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, _ = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))

  def test_video_ssl_input_linear_train(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.train_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (32, 224, 224, 3))
    self.assertAllEqual(label.shape, (600,))

  def test_video_ssl_input_linear_eval(self):
    params = exp_cfg.video_ssl_linear_eval_kinetics600().task.validation_data

    decoder = video_ssl_input.Decoder()
    parser = video_ssl_input.Parser(params).parse_fn(params.is_training)

    seq_example, label = fake_seq_example()

    input_tensor = tf.constant(seq_example.SerializeToString())
    decoded_tensors = decoder.decode(input_tensor)
    output_tensor = parser(decoded_tensors)
    image_features, label = output_tensor
    image = image_features['image']

    self.assertAllEqual(image.shape, (960, 256, 256, 3))
    self.assertAllEqual(label.shape, (600,))


if __name__ == '__main__':
  tf.test.main()
```