Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d74baa35
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4d71c559b21ec9207a328b824ce534bdbaf59f2d"
Commit
d74baa35
authored
Dec 04, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Dec 04, 2020
Browse files
Internal change
PiperOrigin-RevId: 345761622
parent
2cbcddb1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
254 additions
and
0 deletions
+254
-0
official/vision/beta/configs/semantic_segmentation.py
official/vision/beta/configs/semantic_segmentation.py
+104
-0
official/vision/beta/tasks/semantic_segmentation.py
official/vision/beta/tasks/semantic_segmentation.py
+13
-0
official/vision/beta/train_spatial_partitioning.py
official/vision/beta/train_spatial_partitioning.py
+137
-0
No files found.
official/vision/beta/configs/semantic_segmentation.py
View file @
d74baa35
...
@@ -91,6 +91,10 @@ class SemanticSegmentationTask(cfg.TaskConfig):
...
@@ -91,6 +91,10 @@ class SemanticSegmentationTask(cfg.TaskConfig):
train_data
:
DataConfig
=
DataConfig
(
is_training
=
True
)
train_data
:
DataConfig
=
DataConfig
(
is_training
=
True
)
validation_data
:
DataConfig
=
DataConfig
(
is_training
=
False
)
validation_data
:
DataConfig
=
DataConfig
(
is_training
=
False
)
losses
:
Losses
=
Losses
()
losses
:
Losses
=
Losses
()
train_input_partition_dims
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
eval_input_partition_dims
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint_modules
:
Union
[
init_checkpoint_modules
:
Union
[
str
,
List
[
str
]]
=
'all'
# all, backbone, and/or decoder
str
,
List
[
str
]]
=
'all'
# all, backbone, and/or decoder
...
@@ -366,3 +370,103 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
...
@@ -366,3 +370,103 @@ def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
])
])
return
config
return
config
# Cityscapes Dataset (Download and process the dataset yourself)
CITYSCAPES_TRAIN_EXAMPLES
=
2975
CITYSCAPES_VAL_EXAMPLES
=
500
CITYSCAPES_INPUT_PATH_BASE
=
'cityscapes'
@
exp_factory
.
register_config_factory
(
'seg_deeplabv3plus_cityscapes'
)
def
seg_deeplabv3plus_cityscapes
()
->
cfg
.
ExperimentConfig
:
"""Image segmentation on imagenet with resnet deeplabv3+."""
train_batch_size
=
16
eval_batch_size
=
16
steps_per_epoch
=
CITYSCAPES_TRAIN_EXAMPLES
//
train_batch_size
output_stride
=
16
aspp_dilation_rates
=
[
6
,
12
,
18
]
multigrid
=
[
1
,
2
,
4
]
stem_type
=
'v1'
level
=
int
(
np
.
math
.
log2
(
output_stride
))
config
=
cfg
.
ExperimentConfig
(
task
=
SemanticSegmentationTask
(
model
=
SemanticSegmentationModel
(
num_classes
=
20
,
input_size
=
[
None
,
None
,
3
],
backbone
=
backbones
.
Backbone
(
type
=
'dilated_resnet'
,
dilated_resnet
=
backbones
.
DilatedResNet
(
model_id
=
101
,
output_stride
=
output_stride
,
stem_type
=
stem_type
,
multigrid
=
multigrid
)),
decoder
=
decoders
.
Decoder
(
type
=
'aspp'
,
aspp
=
decoders
.
ASPP
(
level
=
level
,
dilation_rates
=
aspp_dilation_rates
)),
head
=
SegmentationHead
(
level
=
level
,
num_convs
=
2
,
feature_fusion
=
'deeplabv3plus'
,
low_level
=
2
,
low_level_num_filters
=
48
),
norm_activation
=
common
.
NormActivation
(
activation
=
'swish'
,
norm_momentum
=
0.99
,
norm_epsilon
=
1e-3
,
use_sync_bn
=
True
)),
losses
=
Losses
(
l2_weight_decay
=
1e-4
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
CITYSCAPES_INPUT_PATH_BASE
,
'train_fine**'
),
output_size
=
[
1024
,
2048
],
train_on_crops
=
True
,
is_training
=
True
,
global_batch_size
=
train_batch_size
,
aug_scale_min
=
0.5
,
aug_scale_max
=
2.0
),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
CITYSCAPES_INPUT_PATH_BASE
,
'val_fine*'
),
output_size
=
[
1024
,
2048
],
is_training
=
False
,
global_batch_size
=
eval_batch_size
,
resize_eval_groundtruth
=
True
,
drop_remainder
=
False
),
# resnet101
init_checkpoint
=
'gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400'
,
init_checkpoint_modules
=
'backbone'
),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
500
*
steps_per_epoch
,
validation_steps
=
CITYSCAPES_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'sgd'
,
'sgd'
:
{
'momentum'
:
0.9
}
},
'learning_rate'
:
{
'type'
:
'polynomial'
,
'polynomial'
:
{
'initial_learning_rate'
:
0.01
,
'decay_steps'
:
500
*
steps_per_epoch
,
'end_learning_rate'
:
0.0
,
'power'
:
0.9
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
official/vision/beta/tasks/semantic_segmentation.py
View file @
d74baa35
...
@@ -163,6 +163,13 @@ class SemanticSegmentationTask(base_task.Task):
...
@@ -163,6 +163,13 @@ class SemanticSegmentationTask(base_task.Task):
A dictionary of logs.
A dictionary of logs.
"""
"""
features
,
labels
=
inputs
features
,
labels
=
inputs
input_partition_dims
=
self
.
task_config
.
train_input_partition_dims
if
input_partition_dims
:
strategy
=
tf
.
distribute
.
get_strategy
()
features
=
strategy
.
experimental_split_to_logical_devices
(
features
,
input_partition_dims
)
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
with
tf
.
GradientTape
()
as
tape
:
with
tf
.
GradientTape
()
as
tape
:
outputs
=
model
(
features
,
training
=
True
)
outputs
=
model
(
features
,
training
=
True
)
...
@@ -211,6 +218,12 @@ class SemanticSegmentationTask(base_task.Task):
...
@@ -211,6 +218,12 @@ class SemanticSegmentationTask(base_task.Task):
"""
"""
features
,
labels
=
inputs
features
,
labels
=
inputs
input_partition_dims
=
self
.
task_config
.
eval_input_partition_dims
if
input_partition_dims
:
strategy
=
tf
.
distribute
.
get_strategy
()
features
=
strategy
.
experimental_split_to_logical_devices
(
features
,
input_partition_dims
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
...
...
official/vision/beta/train_spatial_partitioning.py
0 → 100644
View file @
d74baa35
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision training driver with spatial partitioning."""
from
absl
import
app
from
absl
import
flags
import
gin
import
numpy
as
np
import
tensorflow
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
FLAGS
=
flags
.
FLAGS
def
get_computation_shape_for_model_parallelism
(
input_partition_dims
):
"""Return computation shape to be used for TPUStrategy spatial partition."""
num_logical_devices
=
np
.
prod
(
input_partition_dims
)
if
num_logical_devices
==
1
:
return
[
1
,
1
,
1
,
1
]
if
num_logical_devices
==
2
:
return
[
1
,
1
,
1
,
2
]
if
num_logical_devices
==
4
:
return
[
1
,
2
,
1
,
2
]
if
num_logical_devices
==
8
:
return
[
2
,
2
,
1
,
2
]
if
num_logical_devices
==
16
:
return
[
4
,
2
,
1
,
2
]
def
create_distribution_strategy
(
distribution_strategy
,
tpu_address
,
input_partition_dims
=
None
,
num_gpus
=
None
):
"""Creates distribution strategy to use for computation."""
if
input_partition_dims
is
not
None
:
if
distribution_strategy
!=
'tpu'
:
raise
ValueError
(
'Spatial partitioning is only supported '
'for TPUStrategy.'
)
# When `input_partition_dims` is specified create custom TPUStrategy
# instance with computation shape for model parallelism.
resolver
=
tf
.
distribute
.
cluster_resolver
.
TPUClusterResolver
(
tpu
=
tpu_address
)
if
tpu_address
not
in
(
''
,
'local'
):
tf
.
config
.
experimental_connect_to_cluster
(
resolver
)
topology
=
tf
.
tpu
.
experimental
.
initialize_tpu_system
(
resolver
)
num_replicas
=
resolver
.
get_tpu_system_metadata
().
num_cores
//
np
.
prod
(
input_partition_dims
)
device_assignment
=
tf
.
tpu
.
experimental
.
DeviceAssignment
.
build
(
topology
,
num_replicas
=
num_replicas
,
computation_shape
=
input_partition_dims
)
return
tf
.
distribute
.
TPUStrategy
(
resolver
,
experimental_device_assignment
=
device_assignment
)
return
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
distribution_strategy
,
tpu_address
=
tpu_address
,
num_gpus
=
num_gpus
)
def
main
(
_
):
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_params
)
params
=
train_utils
.
parse_configuration
(
FLAGS
)
model_dir
=
FLAGS
.
model_dir
if
'train'
in
FLAGS
.
mode
:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils
.
serialize_config
(
params
,
model_dir
)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
,
params
.
runtime
.
loss_scale
)
input_partition_dims
=
None
if
FLAGS
.
mode
==
'train_and_eval'
:
if
np
.
prod
(
params
.
task
.
train_input_partition_dims
)
!=
np
.
prod
(
params
.
task
.
eval_input_partition_dims
):
raise
ValueError
(
'Train and eval input partition dims can not be'
'partitioned on the same node'
)
else
:
input_partition_dims
=
get_computation_shape_for_model_parallelism
(
params
.
task
.
train_input_partition_dims
)
elif
FLAGS
.
mode
==
'train'
:
if
params
.
task
.
train_input_partition_dims
:
input_partition_dims
=
get_computation_shape_for_model_parallelism
(
params
.
task
.
train_input_partition_dims
)
elif
FLAGS
.
mode
==
'eval'
or
FLAGS
.
mode
==
'continuous_eval'
:
if
params
.
task
.
eval_input_partition_dims
:
input_partition_dims
=
get_computation_shape_for_model_parallelism
(
params
.
task
.
eval_input_partition_dims
)
distribution_strategy
=
create_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
num_gpus
=
params
.
runtime
.
num_gpus
,
input_partition_dims
=
input_partition_dims
,
tpu_address
=
params
.
runtime
.
tpu
)
with
distribution_strategy
.
scope
():
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
task
,
mode
=
FLAGS
.
mode
,
params
=
params
,
model_dir
=
model_dir
)
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
app
.
run
(
main
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment