Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
09a70c7c
"vscode:/vscode.git/clone" did not exist on "b29c5537480e9ce0c3b7a36719719c4cca8027fc"
Commit
09a70c7c
authored
May 18, 2021
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
May 18, 2021
Browse files
Internal change
PiperOrigin-RevId: 374451731
parent
23db25e9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
190 additions
and
26 deletions
+190
-26
official/vision/beta/configs/experiments/image_classification/jft_resnet50_deeplab_tpu.yaml
...iments/image_classification/jft_resnet50_deeplab_tpu.yaml
+61
-0
official/vision/beta/configs/experiments/image_classification/jft_resnetrs50_i160.yaml
...experiments/image_classification/jft_resnetrs50_i160.yaml
+66
-0
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+63
-26
No files found.
official/vision/beta/configs/experiments/image_classification/jft_resnet50_deeplab_tpu.yaml
0 → 100644
View file @
09a70c7c
# Image classification on JFT (entity split) with a dilated ResNet-101
# (DeepLab-style backbone, output stride 16) on TPU, bfloat16 precision.
# NOTE(review): reconstructed from a whitespace-mangled page scrape; the
# key order and values are exactly as scraped, the indentation follows the
# standard TF Model Garden official/vision/beta config schema — confirm
# against the original commit before relying on nesting of any field.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 18291
    input_size: [224, 224, 3]
    backbone:
      type: 'dilated_resnet'
      dilated_resnet:
        model_id: 101
        output_stride: 16
        stem_type: 'v1'
        multigrid: [1, 2, 4]
    norm_activation:
      activation: 'swish'
  losses:
    # Weight decay disabled for this run.
    l2_weight_decay: 0.0
  train_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'train'
    is_training: true
    global_batch_size: 3840
    # JFT entity labels are multi-label; loss/metrics switch on this flag.
    is_multilabel: true
    shuffle_buffer_size: 500000
    dtype: 'bfloat16'
  validation_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'validation'
    is_training: false
    global_batch_size: 3840
    is_multilabel: true
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  train_steps: 2220000  # 30 epochs
  validation_steps: 156
  validation_interval: 2000
  steps_per_loop: 100
  summary_interval: 2000
  checkpoint_interval: 2000
  best_checkpoint_eval_metric: 'globalPR-AUC'
  best_checkpoint_export_subdir: 'best_ckpt'
  best_checkpoint_metric_comp: 'higher'
  optimizer_config:
    ema: null
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'stepwise'
      stepwise:
        values: [0.48, 0.048, 0.0048, 0.00048]
        boundaries: [730000, 1460000, 1850000]
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 5000
official/vision/beta/configs/experiments/image_classification/jft_resnetrs50_i160.yaml
0 → 100644
View file @
09a70c7c
# Image classification on JFT (entity split) with ResNet-RS-50 at 160x160
# input resolution on TPU, bfloat16 precision.
# NOTE(review): reconstructed from a whitespace-mangled page scrape; the
# key order and values are exactly as scraped, the indentation follows the
# standard TF Model Garden official/vision/beta config schema — confirm
# against the original commit before relying on nesting of any field.
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 18291
    input_size: [160, 160, 3]
    backbone:
      type: 'resnet'
      resnet:
        model_id: 50
        # ResNet-RS tweaks: stem max-pool replaced by strided conv,
        # ResNet-D shortcut, and squeeze-and-excitation.
        replace_stem_max_pool: true
        resnetd_shortcut: true
        se_ratio: 0.25
        stem_type: 'v1'
        stochastic_depth_drop_rate: 0.0
    norm_activation:
      activation: 'swish'
      # NOTE(review): norm_momentum of 0.0 (no BN moving-average memory)
      # is unusual — value copied verbatim from the scrape; verify.
      norm_momentum: 0.0
      use_sync_bn: false
    dropout_rate: 0.25
  losses:
    l2_weight_decay: 0.00004
  train_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'train'
    is_training: true
    global_batch_size: 4096
    # JFT entity labels are multi-label; loss/metrics switch on this flag.
    is_multilabel: true
    shuffle_buffer_size: 500000
    dtype: 'bfloat16'
    # No data augmentation for this run.
    aug_type: null
  validation_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'validation'
    is_training: false
    global_batch_size: 4096
    is_multilabel: true
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  train_steps: 2220000  # 30 epochs
  validation_steps: 156
  validation_interval: 2000
  steps_per_loop: 100
  summary_interval: 2000
  checkpoint_interval: 2000
  best_checkpoint_eval_metric: 'globalPR-AUC'
  best_checkpoint_export_subdir: 'best_ckpt'
  best_checkpoint_metric_comp: 'higher'
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'stepwise'
      stepwise:
        values: [0.48, 0.048, 0.0048, 0.00048]
        boundaries: [730000, 1460000, 1850000]
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 5000
official/vision/beta/tasks/image_classification.py
View file @
09a70c7c
...
@@ -75,15 +75,18 @@ class ImageClassificationTask(base_task.Task):
...
@@ -75,15 +75,18 @@ class ImageClassificationTask(base_task.Task):
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
ckpt_dir_or_file
)
def
build_inputs
(
self
,
def
build_inputs
(
params
:
exp_cfg
.
DataConfig
,
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
params
:
exp_cfg
.
DataConfig
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Builds classification input."""
"""Builds classification input."""
num_classes
=
self
.
task_config
.
model
.
num_classes
num_classes
=
self
.
task_config
.
model
.
num_classes
input_size
=
self
.
task_config
.
model
.
input_size
input_size
=
self
.
task_config
.
model
.
input_size
image_field_key
=
self
.
task_config
.
train_data
.
image_field_key
image_field_key
=
self
.
task_config
.
train_data
.
image_field_key
label_field_key
=
self
.
task_config
.
train_data
.
label_field_key
label_field_key
=
self
.
task_config
.
train_data
.
label_field_key
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
params
.
tfds_name
:
if
params
.
tfds_name
:
if
params
.
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
if
params
.
tfds_name
in
tfds_classification_decoders
.
TFDS_ID_TO_DECODER_MAP
:
...
@@ -93,7 +96,8 @@ class ImageClassificationTask(base_task.Task):
...
@@ -93,7 +96,8 @@ class ImageClassificationTask(base_task.Task):
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
raise
ValueError
(
'TFDS {} is not supported'
.
format
(
params
.
tfds_name
))
else
:
else
:
decoder
=
classification_input
.
Decoder
(
decoder
=
classification_input
.
Decoder
(
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
)
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
,
is_multilabel
=
is_multilabel
)
parser
=
classification_input
.
Parser
(
parser
=
classification_input
.
Parser
(
output_size
=
input_size
[:
2
],
output_size
=
input_size
[:
2
],
...
@@ -102,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
...
@@ -102,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key
=
label_field_key
,
label_field_key
=
label_field_key
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_type
=
params
.
aug_type
,
aug_type
=
params
.
aug_type
,
is_multilabel
=
is_multilabel
,
dtype
=
params
.
dtype
)
dtype
=
params
.
dtype
)
reader
=
input_reader_factory
.
input_reader_generator
(
reader
=
input_reader_factory
.
input_reader_generator
(
...
@@ -117,7 +122,7 @@ class ImageClassificationTask(base_task.Task):
...
@@ -117,7 +122,7 @@ class ImageClassificationTask(base_task.Task):
def
build_losses
(
self
,
def
build_losses
(
self
,
labels
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
,
model_outputs
:
tf
.
Tensor
,
model_outputs
:
tf
.
Tensor
,
aux_losses
:
Optional
[
Any
]
=
None
):
aux_losses
:
Optional
[
Any
]
=
None
)
->
tf
.
Tensor
:
"""Builds sparse categorical cross entropy loss.
"""Builds sparse categorical cross entropy loss.
Args:
Args:
...
@@ -129,15 +134,23 @@ class ImageClassificationTask(base_task.Task):
...
@@ -129,15 +134,23 @@ class ImageClassificationTask(base_task.Task):
The total loss tensor.
The total loss tensor.
"""
"""
losses_config
=
self
.
task_config
.
losses
losses_config
=
self
.
task_config
.
losses
if
losses_config
.
one_hot
:
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
total_loss
=
tf
.
keras
.
losses
.
categorical_crossentropy
(
labels
,
if
not
is_multilabel
:
model_outputs
,
if
losses_config
.
one_hot
:
from_logits
=
True
,
total_loss
=
tf
.
keras
.
losses
.
categorical_crossentropy
(
label_smoothing
=
losses_config
.
label_smoothing
)
labels
,
model_outputs
,
from_logits
=
True
,
label_smoothing
=
losses_config
.
label_smoothing
)
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
)
else
:
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
# Multi-label weighted binary cross entropy loss.
labels
,
model_outputs
,
from_logits
=
True
)
total_loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
model_outputs
)
total_loss
=
tf
.
reduce_sum
(
total_loss
,
axis
=-
1
)
total_loss
=
tf_utils
.
safe_mean
(
total_loss
)
total_loss
=
tf_utils
.
safe_mean
(
total_loss
)
if
aux_losses
:
if
aux_losses
:
...
@@ -145,19 +158,41 @@ class ImageClassificationTask(base_task.Task):
...
@@ -145,19 +158,41 @@ class ImageClassificationTask(base_task.Task):
return
total_loss
return
total_loss
def
build_metrics
(
self
,
training
:
bool
=
True
):
def
build_metrics
(
self
,
training
:
bool
=
True
)
->
List
[
tf
.
keras
.
metrics
.
Metric
]:
"""Gets streaming metrics for training/validation."""
"""Gets streaming metrics for training/validation."""
k
=
self
.
task_config
.
evaluation
.
top_k
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
self
.
task_config
.
losses
.
one_hot
:
if
not
is_multilabel
:
metrics
=
[
k
=
self
.
task_config
.
evaluation
.
top_k
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
if
self
.
task_config
.
losses
.
one_hot
:
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
metrics
=
[
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
else
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
else
:
else
:
metrics
=
[
metrics
=
[]
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
# These metrics destabilize the training if included in training. The jobs
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
# fail due to OOM.
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
# TODO(arashwan): Investigate adding following metric to train.
if
not
training
:
metrics
=
[
tf
.
keras
.
metrics
.
AUC
(
name
=
'globalPR-AUC'
,
curve
=
'PR'
,
multi_label
=
False
,
from_logits
=
True
),
tf
.
keras
.
metrics
.
AUC
(
name
=
'meanlPR-AUC'
,
curve
=
'PR'
,
multi_label
=
True
,
num_labels
=
self
.
task_config
.
model
.
num_classes
,
from_logits
=
True
),
]
return
metrics
return
metrics
def
train_step
(
self
,
def
train_step
(
self
,
...
@@ -177,7 +212,8 @@ class ImageClassificationTask(base_task.Task):
...
@@ -177,7 +212,8 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
A dictionary of logs.
"""
"""
features
,
labels
=
inputs
features
,
labels
=
inputs
if
self
.
task_config
.
losses
.
one_hot
:
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
self
.
task_config
.
losses
.
one_hot
and
not
is_multilabel
:
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
...
@@ -233,7 +269,8 @@ class ImageClassificationTask(base_task.Task):
...
@@ -233,7 +269,8 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
A dictionary of logs.
"""
"""
features
,
labels
=
inputs
features
,
labels
=
inputs
if
self
.
task_config
.
losses
.
one_hot
:
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
self
.
task_config
.
losses
.
one_hot
and
not
is_multilabel
:
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
self
.
inference_step
(
features
,
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment