Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SED_pytorch
Commits
f55a786e
Commit
f55a786e
authored
Jun 05, 2024
by
luopl
Browse files
Initial commit
parents
Pipeline
#1081
canceled with stages
Changes
181
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
324 additions
and
0 deletions
+324
-0
train_net.py
train_net.py
+324
-0
No files found.
train_net.py
0 → 100644
View file @
f55a786e
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
MaskFormer Training Script.
This script is a simplified version of the training script in detectron2/tools.
"""
import
copy
import
itertools
import
logging
import
os
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
List
,
Set
import
torch
import
detectron2.utils.comm
as
comm
from
detectron2.checkpoint
import
DetectionCheckpointer
from
detectron2.config
import
get_cfg
from
detectron2.data
import
MetadataCatalog
,
build_detection_train_loader
from
detectron2.engine
import
DefaultTrainer
,
default_argument_parser
,
default_setup
,
launch
from
detectron2.evaluation
import
CityscapesInstanceEvaluator
,
CityscapesSemSegEvaluator
,
\
COCOEvaluator
,
COCOPanopticEvaluator
,
DatasetEvaluators
,
SemSegEvaluator
,
verify_results
,
\
DatasetEvaluator
from
detectron2.projects.deeplab
import
add_deeplab_config
,
build_lr_scheduler
from
detectron2.solver.build
import
maybe_add_gradient_clipping
from
detectron2.utils.logger
import
setup_logger
from
detectron2.utils.file_io
import
PathManager
import
numpy
as
np
from
PIL
import
Image
import
glob
import
pycocotools.mask
as
mask_util
from
detectron2.data
import
DatasetCatalog
,
MetadataCatalog
from
detectron2.utils.comm
import
all_gather
,
is_main_process
,
synchronize
import
json
# from detectron2.evaluation import SemSegGzeroEvaluator
# from mask_former.evaluation.sem_seg_evaluation_gzero import SemSegGzeroEvaluator
class
VOCbEvaluator
(
SemSegEvaluator
):
"""
Evaluate semantic segmentation metrics.
"""
def
process
(
self
,
inputs
,
outputs
):
"""
Args:
inputs: the inputs to a model.
It is a list of dicts. Each dict corresponds to an image and
contains keys like "height", "width", "file_name".
outputs: the outputs of a model. It is either list of semantic segmentation predictions
(Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
segmentation prediction in the same format.
"""
for
input
,
output
in
zip
(
inputs
,
outputs
):
output
=
output
[
"sem_seg"
].
argmax
(
dim
=
0
).
to
(
self
.
_cpu_device
)
pred
=
np
.
array
(
output
,
dtype
=
np
.
int
)
pred
[
pred
>=
20
]
=
20
with
PathManager
.
open
(
self
.
input_file_to_gt_file
[
input
[
"file_name"
]],
"rb"
)
as
f
:
gt
=
np
.
array
(
Image
.
open
(
f
),
dtype
=
np
.
int
)
gt
[
gt
==
self
.
_ignore_label
]
=
self
.
_num_classes
self
.
_conf_matrix
+=
np
.
bincount
(
(
self
.
_num_classes
+
1
)
*
pred
.
reshape
(
-
1
)
+
gt
.
reshape
(
-
1
),
minlength
=
self
.
_conf_matrix
.
size
,
).
reshape
(
self
.
_conf_matrix
.
shape
)
self
.
_predictions
.
extend
(
self
.
encode_json_sem_seg
(
pred
,
input
[
"file_name"
]))
# MaskFormer
from
sed
import
(
DETRPanopticDatasetMapper
,
MaskFormerPanopticDatasetMapper
,
MaskFormerSemanticDatasetMapper
,
SemanticSegmentorWithTTA
,
add_sed_config
,
)
class
Trainer
(
DefaultTrainer
):
"""
Extension of the Trainer class adapted to DETR.
"""
@
classmethod
def
build_evaluator
(
cls
,
cfg
,
dataset_name
,
output_folder
=
None
):
"""
Create evaluator(s) for a given dataset.
This uses the special metadata "evaluator_type" associated with each
builtin dataset. For your own dataset, you can simply create an
evaluator manually in your script and do not have to worry about the
hacky if-else logic here.
"""
if
output_folder
is
None
:
output_folder
=
os
.
path
.
join
(
cfg
.
OUTPUT_DIR
,
"inference"
)
evaluator_list
=
[]
evaluator_type
=
MetadataCatalog
.
get
(
dataset_name
).
evaluator_type
if
evaluator_type
in
[
"sem_seg"
,
"ade20k_panoptic_seg"
]:
evaluator_list
.
append
(
SemSegEvaluator
(
dataset_name
,
distributed
=
True
,
output_dir
=
output_folder
,
)
)
if
evaluator_type
==
"sem_seg_background"
:
evaluator_list
.
append
(
VOCbEvaluator
(
dataset_name
,
distributed
=
True
,
output_dir
=
output_folder
,
)
)
if
evaluator_type
==
"coco"
:
evaluator_list
.
append
(
COCOEvaluator
(
dataset_name
,
output_dir
=
output_folder
))
if
evaluator_type
in
[
"coco_panoptic_seg"
,
"ade20k_panoptic_seg"
,
"cityscapes_panoptic_seg"
,
]:
evaluator_list
.
append
(
COCOPanopticEvaluator
(
dataset_name
,
output_folder
))
if
evaluator_type
==
"cityscapes_instance"
:
assert
(
torch
.
cuda
.
device_count
()
>=
comm
.
get_rank
()
),
"CityscapesEvaluator currently do not work with multiple machines."
return
CityscapesInstanceEvaluator
(
dataset_name
)
if
evaluator_type
==
"cityscapes_sem_seg"
:
assert
(
torch
.
cuda
.
device_count
()
>=
comm
.
get_rank
()
),
"CityscapesEvaluator currently do not work with multiple machines."
return
CityscapesSemSegEvaluator
(
dataset_name
)
if
evaluator_type
==
"cityscapes_panoptic_seg"
:
assert
(
torch
.
cuda
.
device_count
()
>=
comm
.
get_rank
()
),
"CityscapesEvaluator currently do not work with multiple machines."
evaluator_list
.
append
(
CityscapesSemSegEvaluator
(
dataset_name
))
if
len
(
evaluator_list
)
==
0
:
raise
NotImplementedError
(
"no Evaluator for the dataset {} with the type {}"
.
format
(
dataset_name
,
evaluator_type
)
)
elif
len
(
evaluator_list
)
==
1
:
return
evaluator_list
[
0
]
return
DatasetEvaluators
(
evaluator_list
)
@
classmethod
def
build_train_loader
(
cls
,
cfg
):
# Semantic segmentation dataset mapper
if
cfg
.
INPUT
.
DATASET_MAPPER_NAME
==
"mask_former_semantic"
:
mapper
=
MaskFormerSemanticDatasetMapper
(
cfg
,
True
)
# Panoptic segmentation dataset mapper
elif
cfg
.
INPUT
.
DATASET_MAPPER_NAME
==
"mask_former_panoptic"
:
mapper
=
MaskFormerPanopticDatasetMapper
(
cfg
,
True
)
# DETR-style dataset mapper for COCO panoptic segmentation
elif
cfg
.
INPUT
.
DATASET_MAPPER_NAME
==
"detr_panoptic"
:
mapper
=
DETRPanopticDatasetMapper
(
cfg
,
True
)
else
:
mapper
=
None
return
build_detection_train_loader
(
cfg
,
mapper
=
mapper
)
@
classmethod
def
build_lr_scheduler
(
cls
,
cfg
,
optimizer
):
"""
It now calls :func:`detectron2.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return
build_lr_scheduler
(
cfg
,
optimizer
)
@
classmethod
def
build_optimizer
(
cls
,
cfg
,
model
):
weight_decay_norm
=
cfg
.
SOLVER
.
WEIGHT_DECAY_NORM
weight_decay_embed
=
cfg
.
SOLVER
.
WEIGHT_DECAY_EMBED
defaults
=
{}
defaults
[
"lr"
]
=
cfg
.
SOLVER
.
BASE_LR
defaults
[
"weight_decay"
]
=
cfg
.
SOLVER
.
WEIGHT_DECAY
norm_module_types
=
(
torch
.
nn
.
BatchNorm1d
,
torch
.
nn
.
BatchNorm2d
,
torch
.
nn
.
BatchNorm3d
,
torch
.
nn
.
SyncBatchNorm
,
# NaiveSyncBatchNorm inherits from BatchNorm2d
torch
.
nn
.
GroupNorm
,
torch
.
nn
.
InstanceNorm1d
,
torch
.
nn
.
InstanceNorm2d
,
torch
.
nn
.
InstanceNorm3d
,
torch
.
nn
.
LayerNorm
,
torch
.
nn
.
LocalResponseNorm
,
)
params
:
List
[
Dict
[
str
,
Any
]]
=
[]
memo
:
Set
[
torch
.
nn
.
parameter
.
Parameter
]
=
set
()
# import ipdb;
# ipdb.set_trace()
for
module_name
,
module
in
model
.
named_modules
():
for
module_param_name
,
value
in
module
.
named_parameters
(
recurse
=
False
):
if
not
value
.
requires_grad
:
continue
# Avoid duplicating parameters
if
value
in
memo
:
continue
memo
.
add
(
value
)
hyperparams
=
copy
.
copy
(
defaults
)
if
"backbone"
in
module_name
:
hyperparams
[
"lr"
]
=
hyperparams
[
"lr"
]
*
cfg
.
SOLVER
.
BACKBONE_MULTIPLIER
if
"clip_model"
in
module_name
:
hyperparams
[
"lr"
]
=
hyperparams
[
"lr"
]
*
cfg
.
SOLVER
.
CLIP_MULTIPLIER
# for deformable detr
if
(
"relative_position_bias_table"
in
module_param_name
or
"absolute_pos_embed"
in
module_param_name
):
print
(
module_param_name
)
hyperparams
[
"weight_decay"
]
=
0.0
if
isinstance
(
module
,
norm_module_types
):
hyperparams
[
"weight_decay"
]
=
weight_decay_norm
if
isinstance
(
module
,
torch
.
nn
.
Embedding
):
hyperparams
[
"weight_decay"
]
=
weight_decay_embed
params
.
append
({
"params"
:
[
value
],
**
hyperparams
})
def
maybe_add_full_model_gradient_clipping
(
optim
):
# detectron2 doesn't have full model gradient clipping now
clip_norm_val
=
cfg
.
SOLVER
.
CLIP_GRADIENTS
.
CLIP_VALUE
enable
=
(
cfg
.
SOLVER
.
CLIP_GRADIENTS
.
ENABLED
and
cfg
.
SOLVER
.
CLIP_GRADIENTS
.
CLIP_TYPE
==
"full_model"
and
clip_norm_val
>
0.0
)
class
FullModelGradientClippingOptimizer
(
optim
):
def
step
(
self
,
closure
=
None
):
all_params
=
itertools
.
chain
(
*
[
x
[
"params"
]
for
x
in
self
.
param_groups
])
torch
.
nn
.
utils
.
clip_grad_norm_
(
all_params
,
clip_norm_val
)
super
().
step
(
closure
=
closure
)
return
FullModelGradientClippingOptimizer
if
enable
else
optim
optimizer_type
=
cfg
.
SOLVER
.
OPTIMIZER
if
optimizer_type
==
"SGD"
:
optimizer
=
maybe_add_full_model_gradient_clipping
(
torch
.
optim
.
SGD
)(
params
,
cfg
.
SOLVER
.
BASE_LR
,
momentum
=
cfg
.
SOLVER
.
MOMENTUM
)
elif
optimizer_type
==
"ADAMW"
:
optimizer
=
maybe_add_full_model_gradient_clipping
(
torch
.
optim
.
AdamW
)(
params
,
cfg
.
SOLVER
.
BASE_LR
)
else
:
raise
NotImplementedError
(
f
"no optimizer type
{
optimizer_type
}
"
)
if
not
cfg
.
SOLVER
.
CLIP_GRADIENTS
.
CLIP_TYPE
==
"full_model"
:
optimizer
=
maybe_add_gradient_clipping
(
cfg
,
optimizer
)
return
optimizer
@
classmethod
def
test_with_TTA
(
cls
,
cfg
,
model
):
logger
=
logging
.
getLogger
(
"detectron2.trainer"
)
# In the end of training, run an evaluation with TTA.
logger
.
info
(
"Running inference with test-time augmentation ..."
)
model
=
SemanticSegmentorWithTTA
(
cfg
,
model
)
evaluators
=
[
cls
.
build_evaluator
(
cfg
,
name
,
output_folder
=
os
.
path
.
join
(
cfg
.
OUTPUT_DIR
,
"inference_TTA"
)
)
for
name
in
cfg
.
DATASETS
.
TEST
]
res
=
cls
.
test
(
cfg
,
model
,
evaluators
)
res
=
OrderedDict
({
k
+
"_TTA"
:
v
for
k
,
v
in
res
.
items
()})
return
res
def
setup
(
args
):
"""
Create configs and perform basic setups.
"""
cfg
=
get_cfg
()
# for poly lr schedule
add_deeplab_config
(
cfg
)
add_sed_config
(
cfg
)
cfg
.
merge_from_file
(
args
.
config_file
)
cfg
.
merge_from_list
(
args
.
opts
)
cfg
.
freeze
()
default_setup
(
cfg
,
args
)
# Setup logger for "mask_former" module
setup_logger
(
output
=
cfg
.
OUTPUT_DIR
,
distributed_rank
=
comm
.
get_rank
(),
name
=
"mask_former"
)
return
cfg
def
main
(
args
):
cfg
=
setup
(
args
)
torch
.
set_float32_matmul_precision
(
"high"
)
if
args
.
eval_only
:
model
=
Trainer
.
build_model
(
cfg
)
DetectionCheckpointer
(
model
,
save_dir
=
cfg
.
OUTPUT_DIR
).
resume_or_load
(
cfg
.
MODEL
.
WEIGHTS
,
resume
=
args
.
resume
)
res
=
Trainer
.
test
(
cfg
,
model
)
if
cfg
.
TEST
.
AUG
.
ENABLED
:
res
.
update
(
Trainer
.
test_with_TTA
(
cfg
,
model
))
if
comm
.
is_main_process
():
verify_results
(
cfg
,
res
)
return
res
trainer
=
Trainer
(
cfg
)
trainer
.
resume_or_load
(
resume
=
args
.
resume
)
return
trainer
.
train
()
if
__name__
==
"__main__"
:
args
=
default_argument_parser
().
parse_args
()
print
(
"Command Line Args:"
,
args
)
launch
(
main
,
args
.
num_gpus
,
num_machines
=
args
.
num_machines
,
machine_rank
=
args
.
machine_rank
,
dist_url
=
args
.
dist_url
,
args
=
(
args
,),
)
Prev
1
…
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment